1010import vigra
1111import torch
1212import z5py
13+ import json
1314
1415from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper
1516from elf .wrapper .resized_volume import ResizedVolume
1819from torch_em .util .prediction import predict_with_halo
1920from tqdm import tqdm
2021
"""
Prediction using the distance U-Net.

Parallelization across multiple GPUs is currently only possible by calling the
functions directly. The functions for parallel execution end with '_slurm' and
split the workflow into separate preprocessing, prediction, and segmentation steps.
"""
2127
2228class SelectChannel (SimpleTransformationWrapper ):
2329 def __init__ (self , volume , channel ):
@@ -37,13 +43,13 @@ def ndim(self):
3743 return self ._volume .ndim - 1
3844
3945
40- def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo ):
46+ def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo , prediction_instances = 1 , slurm_task_id = 0 , mean = None , std = None ):
4147 with warnings .catch_warnings ():
4248 warnings .simplefilter ("ignore" )
4349 if os .path .isdir (model_path ):
4450 model = load_model (model_path )
4551 else :
46- model = torch .load (model_path )
52+ model = torch .load (model_path , weights_only = False )
4753
4854 mask_path = os .path .join (output_folder , "mask.zarr" )
4955 image_mask = z5py .File (mask_path , "r" )["mask" ]
@@ -66,23 +72,25 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
6672 image_mask = ResizedVolume (image_mask , new_shape , order = 0 )
6773
6874 have_cuda = torch .cuda .is_available ()
75+
6976 if block_shape is None :
70- block_shape = tuple ([ 2 * ch for ch in input_ . chunks ] ) if have_cuda else input_ .chunks
77+ block_shape = ( 128 , 128 , 128 ) if have_cuda else input_ .chunks
7178 if halo is None :
72- halo = (16 , 64 , 64 ) if have_cuda else ( 16 , 32 , 32 )
79+ halo = (16 , 32 , 32 )
7380 if have_cuda :
7481 print ("Predict with GPU" )
7582 gpu_ids = [0 ]
7683 else :
7784 print ("Predict with CPU" )
7885 gpu_ids = ["cpu" ]
7986
80- # Compute the global mean and standard deviation.
81- n_threads = min (16 , mp .cpu_count ())
82- mean , std = parallel .mean_and_std (
83- input_ , block_shape = block_shape , n_threads = n_threads , verbose = True ,
84- mask = image_mask
85- )
87+ if mean is None or std is None :
88+ # Compute the global mean and standard deviation.
89+ n_threads = min (16 , mp .cpu_count ())
90+ mean , std = parallel .mean_and_std (
91+ input_ , block_shape = block_shape , n_threads = n_threads , verbose = True ,
92+ mask = image_mask
93+ )
8694 print ("Mean and standard deviation computed for the full volume:" )
8795 print (mean , std )
8896
@@ -98,6 +106,17 @@ def postprocess(x):
98106 x [1 ] = vigra .filters .gaussianSmoothing (x [1 ], sigma = 2.0 )
99107 return x
100108
109+ shape = input_ .shape
110+ ndim = len (shape )
111+
112+ blocking = nt .blocking ([0 ] * ndim , shape , block_shape )
113+ n_blocks = blocking .numberOfBlocks
114+ if prediction_instances != 1 :
115+ iteration_ids = [x .tolist () for x in np .array_split (list (range (n_blocks )), prediction_instances )]
116+ slurm_iteration = iteration_ids [slurm_task_id ]
117+ else :
118+ slurm_iteration = list (range (n_blocks ))
119+
101120 output_path = os .path .join (output_folder , "predictions.zarr" )
102121 with open_file (output_path , "a" ) as f :
103122 output = f .require_dataset (
@@ -113,6 +132,7 @@ def postprocess(x):
113132 gpu_ids = gpu_ids , block_shape = block_shape , halo = halo ,
114133 output = output , preprocess = preprocess , postprocess = postprocess ,
115134 mask = image_mask ,
135+ iter_list = slurm_iteration ,
116136 )
117137
118138 return original_shape
@@ -223,6 +243,30 @@ def write_block(block_id):
223243 tp .map (write_block , range (blocking .numberOfBlocks ))
224244
225245
def calc_mean_and_std(input_path, input_key, output_folder):
    """Compute the global mean and standard deviation of the full volume.

    The results are saved to 'mean_std.json' inside the output folder, so that
    later (parallel) prediction tasks can reuse them instead of recomputing
    them per task.

    Args:
        input_path: Path to the input data.
        input_key: Dataset key inside the input container. If None, the input
            is read as a plain image file via imageio.
        output_folder: Folder that already contains 'mask.zarr' and into which
            'mean_std.json' is written.
    """
    json_file = os.path.join(output_folder, "mean_std.json")
    mask_path = os.path.join(output_folder, "mask.zarr")
    image_mask = z5py.File(mask_path, "r")["mask"]

    if input_key is None:
        input_ = imageio.imread(input_path)
    else:
        input_ = open_file(input_path, "r")[input_key]

    # Compute the global mean and standard deviation block-wise in parallel.
    # NOTE(review): the block shape is derived from input_.chunks, which only
    # exists for chunked datasets — the imageio branch yields a plain ndarray
    # without a 'chunks' attribute; confirm plain-image inputs are expected here.
    n_threads = min(16, mp.cpu_count())
    mean, std = parallel.mean_and_std(
        input_, block_shape=tuple(2 * i for i in input_.chunks),
        n_threads=n_threads, verbose=True, mask=image_mask,
    )

    # Cast to plain Python floats: numpy scalar types are not JSON serializable.
    with open(json_file, "w") as f:
        json.dump({"mean": float(mean), "std": float(std)}, f)

226270def run_unet_prediction (
227271 input_path , input_key ,
228272 output_folder , model_path ,
@@ -239,3 +283,56 @@ def run_unet_prediction(
239283
240284 pmap_out = os .path .join (output_folder , "predictions.zarr" )
241285 segmentation_impl (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
286+
287+ #---Workflow for parallel prediction using slurm---
288+
def run_unet_prediction_preprocess_slurm(
    input_path, input_key, output_folder,
):
    """Run the preprocessing required for parallel U-Net prediction via slurm.

    Two artifacts are created inside the output folder: the foreground mask in
    'mask.zarr', and the precomputed global mean / standard deviation in
    'mean_std.json', both of which the prediction tasks reuse later.

    Args:
        input_path: Path to the input data.
        input_key: Dataset key inside the input container (or None for a plain image).
        output_folder: Folder the preprocessing artifacts are written to.
    """
    find_mask(input_path, input_key, output_folder)
    calc_mean_and_std(input_path, input_key, output_folder)
300+
def run_unet_prediction_slurm(
    input_path, input_key, output_folder, model_path,
    scale=None,
    block_shape=None, halo=None, prediction_instances=1,
):
    """Run one slurm-array task of the parallelized U-Net prediction.

    The volume's blocks are split into `prediction_instances` chunks and this
    task only predicts the chunk selected by the SLURM_ARRAY_TASK_ID
    environment variable (set by sbatch's '-a' option).

    Args:
        input_path: Path to the input data.
        input_key: Dataset key inside the input container (or None for a plain image).
        output_folder: Folder for masks, statistics, and predictions.
        model_path: Path to the trained model (directory or torch checkpoint).
        scale: Optional rescaling factor passed through to prediction_impl.
        block_shape: Optional block shape for blockwise prediction.
        halo: Optional halo for blockwise prediction.
        prediction_instances: Total number of parallel prediction tasks.

    Returns:
        The original shape of the input volume, as reported by prediction_impl.

    Raises:
        ValueError: If SLURM_ARRAY_TASK_ID is not set in the environment.
    """
    os.makedirs(output_folder, exist_ok=True)
    prediction_instances = int(prediction_instances)

    slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
    if slurm_task_id is None:
        raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")
    slurm_task_id = int(slurm_task_id)

    # The mask only needs to be computed once; ideally it is created up front
    # by the preprocess step so that the array tasks do not race here.
    if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
        find_mask(input_path, input_key, output_folder)

    # Load the precomputed mean / standard deviation if available; otherwise
    # prediction_impl recomputes them (wasteful when done once per task).
    mean, std = None, None
    mean_std_file = os.path.join(output_folder, "mean_std.json")
    if os.path.isfile(mean_std_file):
        with open(mean_std_file) as f:
            d = json.load(f)
        mean, std = float(d["mean"]), float(d["std"])

    original_shape = prediction_impl(
        input_path, input_key, output_folder, model_path, scale, block_shape, halo,
        prediction_instances=prediction_instances, slurm_task_id=slurm_task_id,
        mean=mean, std=std,
    )
    return original_shape
333+
# does NOT need GPU, FIXME: only run on CPU
def run_unet_segmentation_slurm(output_folder, min_size):
    """Run the segmentation step on the predictions produced by the slurm prediction tasks.

    Args:
        output_folder: Folder containing 'predictions.zarr'; segmentation results
            are written into the same folder.
        min_size: Minimal segment size; converted to int before use.
    """
    prediction_path = os.path.join(output_folder, "predictions.zarr")
    segmentation_impl(prediction_path, output_folder, min_size=int(min_size))
0 commit comments