1010import vigra
1111import torch
1212import z5py
13+ import json
1314
1415from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper
1516from elf .wrapper .resized_volume import ResizedVolume
@@ -37,7 +38,7 @@ def ndim(self):
3738 return self ._volume .ndim - 1
3839
3940
40- def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances = 1 , slurm_task_id = 0 ):
41+ def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances = 1 , slurm_task_id = 0 , mean = None , std = None ):
4142 with warnings .catch_warnings ():
4243 warnings .simplefilter ("ignore" )
4344 if os .path .isdir (model_path ):
@@ -79,12 +80,13 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
7980 if halo is None :
8081 halo = (16 , 32 , 32 )
8182
82- # Compute the global mean and standard deviation.
83- n_threads = min (16 , mp .cpu_count ())
84- mean , std = parallel .mean_and_std (
85- input_ , block_shape = tuple ([2 * i for i in input_ .chunks ]), n_threads = n_threads , verbose = True ,
86- mask = image_mask
87- )
83+ if None == mean or None == std :
84+ # Compute the global mean and standard deviation.
85+ n_threads = min (16 , mp .cpu_count ())
86+ mean , std = parallel .mean_and_std (
87+ input_ , block_shape = tuple ([2 * i for i in input_ .chunks ]), n_threads = n_threads , verbose = True ,
88+ mask = image_mask
89+ )
8890 print ("Mean and standard deviation computed for the full volume:" )
8991 print (mean , std )
9092
@@ -238,6 +240,30 @@ def write_block(block_id):
238240 tp .map (write_block , range (blocking .numberOfBlocks ))
239241
240242
243+ def calc_mean_and_std (input_path , input_key , output_folder ):
244+ """
245+ Calculate mean and standard deviation of full volume.
246+ Parameters are saved in 'mean_std.json' within the output folder.
247+ """
248+ json_file = os .path .join (output_folder , "mean_std.json" )
249+ mask_path = os .path .join (output_folder , "mask.zarr" )
250+ image_mask = z5py .File (mask_path , "r" )["mask" ]
251+
252+ if input_key is None :
253+ input_ = imageio .imread (input_path )
254+ else :
255+ input_ = open_file (input_path , "r" )[input_key ]
256+
257+ # Compute the global mean and standard deviation.
258+ n_threads = min (16 , mp .cpu_count ())
259+ mean , std = parallel .mean_and_std (
260+ input_ , block_shape = tuple ([2 * i for i in input_ .chunks ]), n_threads = n_threads , verbose = True ,
261+ mask = image_mask
262+ )
263+ ddict = {"mean" :str (mean ), "std" : str (std )}
264+ with open (json_file , "w" ) as f :
265+ json .dump (ddict , f )
266+
241267def run_unet_prediction (
242268 input_path , input_key ,
243269 output_folder , model_path ,
@@ -255,6 +281,18 @@ def run_unet_prediction(
255281 pmap_out = os .path .join (output_folder , "predictions.zarr" )
256282 segmentation_impl (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
257283
284+ def run_unet_prediction_slurm_preprocess (
285+ input_path , input_key , output_folder ,
286+ ):
287+ """
288+ Pre-processing for the parallel prediction with U-Net models.
289+ Masks are stored in mask.zarr in the output folder.
290+ The mean and standard deviation are precomputed for later usage during prediction
291+ and stored in a JSON file within the output folder as mean_std.json
292+ """
293+ find_mask (input_path , input_key , output_folder )
294+ calc_mean_and_std (input_path , input_key , output_folder )
295+
258296def run_unet_prediction_slurm (
259297 input_path , input_key , output_folder , model_path ,
260298 scale = None ,
@@ -269,10 +307,21 @@ def run_unet_prediction_slurm(
269307 else :
270308 raise ValueError ("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH." )
271309
272- find_mask (input_path , input_key , output_folder )
310+ if not os .path .isdir (os .path .join (output_folder , "mask.zarr" )):
311+ find_mask (input_path , input_key , output_folder )
312+
313+ # get pre-computed mean and standard deviation of full volume from JSON file
314+ if os .path .isfile (os .path .join (output_folder , "mean_std.json" )):
315+ with open (os .path .join (output_folder , "mean_std.json" )) as f :
316+ d = json .load (f )
317+ mean = float (d ["mean" ])
318+ std = float (d ["std" ])
319+ else :
320+ mean = None
321+ std = None
273322
274323 original_shape = prediction_impl (
275- input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances , slurm_task_id
324+ input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances , slurm_task_id , mean = mean , std = std
276325 )
277326
278327# does NOT need GPU, FIXME: only run on CPU
0 commit comments