Skip to content

Commit e185ee1

Browse files
committed
Calculation of mean and standard deviation as preprocessing
1 parent 12799b1 commit e185ee1

File tree

1 file changed

+58
-9
lines changed

1 file changed

+58
-9
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import vigra
1111
import torch
1212
import z5py
13+
import json
1314

1415
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
1516
from elf.wrapper.resized_volume import ResizedVolume
@@ -37,7 +38,7 @@ def ndim(self):
3738
return self._volume.ndim - 1
3839

3940

40-
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0):
41+
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None):
4142
with warnings.catch_warnings():
4243
warnings.simplefilter("ignore")
4344
if os.path.isdir(model_path):
@@ -79,12 +80,13 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
7980
if halo is None:
8081
halo = (16, 32, 32)
8182

82-
# Compute the global mean and standard deviation.
83-
n_threads = min(16, mp.cpu_count())
84-
mean, std = parallel.mean_and_std(
85-
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
86-
mask=image_mask
87-
)
83+
if None == mean or None == std:
84+
# Compute the global mean and standard deviation.
85+
n_threads = min(16, mp.cpu_count())
86+
mean, std = parallel.mean_and_std(
87+
input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True,
88+
mask=image_mask
89+
)
8890
print("Mean and standard deviation computed for the full volume:")
8991
print(mean, std)
9092

@@ -238,6 +240,30 @@ def write_block(block_id):
238240
tp.map(write_block, range(blocking.numberOfBlocks))
239241

240242

243+
def calc_mean_and_std(input_path, input_key, output_folder):
    """Calculate the mean and standard deviation of the full volume.

    The statistics are computed with `parallel.mean_and_std`, using the dataset
    'mask' from 'mask.zarr' in the output folder as mask. That mask file must
    therefore exist before this function is called.
    The result is saved in 'mean_std.json' within the output folder.

    Args:
        input_path: Path to the input data. It is read with `imageio.imread`
            if `input_key` is None and opened with `open_file` otherwise.
        input_key: Key of the dataset inside the input file, or None.
        output_folder: Folder that contains 'mask.zarr' and that receives
            'mean_std.json'.
    """
    json_file = os.path.join(output_folder, "mean_std.json")
    mask_path = os.path.join(output_folder, "mask.zarr")
    image_mask = z5py.File(mask_path, "r")["mask"]

    if input_key is None:
        input_ = imageio.imread(input_path)
    else:
        input_ = open_file(input_path, "r")[input_key]

    # Data read via imageio is an in-memory array without a 'chunks' attribute,
    # and an unchunked dataset may report 'chunks' as None. Fall back to a
    # default chunk shape in both cases instead of failing.
    # NOTE(review): 64 per axis is an assumed default - confirm it suits the data.
    chunks = getattr(input_, "chunks", None)
    if chunks is None:
        chunks = (64,) * input_.ndim
    block_shape = tuple(2 * chunk for chunk in chunks)

    # Compute the global mean and standard deviation.
    n_threads = min(16, mp.cpu_count())
    mean, std = parallel.mean_and_std(
        input_, block_shape=block_shape, n_threads=n_threads, verbose=True,
        mask=image_mask
    )

    # The values are stored as strings; readers convert them back with float().
    ddict = {"mean": str(mean), "std": str(std)}
    with open(json_file, "w") as f:
        json.dump(ddict, f)
266+
241267
def run_unet_prediction(
242268
input_path, input_key,
243269
output_folder, model_path,
@@ -255,6 +281,18 @@ def run_unet_prediction(
255281
pmap_out = os.path.join(output_folder, "predictions.zarr")
256282
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
257283

284+
def run_unet_prediction_slurm_preprocess(input_path, input_key, output_folder):
    """Run the pre-processing for the parallel prediction with U-Net models.

    The steps are executed in this order, each with the same arguments:
      1. `find_mask`: the masks are stored in mask.zarr in the output folder.
      2. `calc_mean_and_std`: the mean and standard deviation are precomputed
         for later usage during prediction and stored in mean_std.json in the
         output folder.
    """
    # The order matters: the statistics are computed with the mask from step 1.
    preprocessing_steps = (find_mask, calc_mean_and_std)
    for step in preprocessing_steps:
        step(input_path, input_key, output_folder)
295+
258296
def run_unet_prediction_slurm(
259297
input_path, input_key, output_folder, model_path,
260298
scale=None,
@@ -269,10 +307,21 @@ def run_unet_prediction_slurm(
269307
else:
270308
raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")
271309

272-
find_mask(input_path, input_key, output_folder)
310+
if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
311+
find_mask(input_path, input_key, output_folder)
312+
313+
# get pre-computed mean and standard deviation of full volume from JSON file
314+
if os.path.isfile(os.path.join(output_folder, "mean_std.json")):
315+
with open(os.path.join(output_folder, "mean_std.json")) as f:
316+
d = json.load(f)
317+
mean = float(d["mean"])
318+
std = float(d["std"])
319+
else:
320+
mean = None
321+
std = None
273322

274323
original_shape = prediction_impl(
275-
input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id
324+
input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id, mean=mean, std=std
276325
)
277326

278327
# does NOT need GPU, FIXME: only run on CPU

0 commit comments

Comments
 (0)