@@ -37,13 +37,13 @@ def ndim(self):
3737        return  self ._volume .ndim  -  1 
3838
3939
40- def  prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo ):
40+ def  prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo ,  prediction_instances = 1 ,  slurm_task_id = 0 ):
4141    with  warnings .catch_warnings ():
4242        warnings .simplefilter ("ignore" )
4343        if  os .path .isdir (model_path ):
4444            model  =  load_model (model_path )
4545        else :
46-             model  =  torch .load (model_path )
46+             model  =  torch .load (model_path ,  weights_only = False )
4747
4848    mask_path  =  os .path .join (output_folder , "mask.zarr" )
4949    image_mask  =  z5py .File (mask_path , "r" )["mask" ]
@@ -65,22 +65,24 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
6565        input_  =  ResizedVolume (input_ , shape = new_shape , order = 3 )
6666        image_mask  =  ResizedVolume (image_mask , new_shape , order = 0 )
6767
68+     chunks   =  (128 , 128 , 128 )
69+     block_shape  =  chunks 
70+ 
6871    have_cuda  =  torch .cuda .is_available ()
69-     if  block_shape  is  None :
70-         block_shape  =  tuple ([2  *  ch  for  ch  in  input_ .chunks ]) if  have_cuda  else  input_ .chunks 
71-     if  halo  is  None :
72-         halo  =  (16 , 64 , 64 ) if  have_cuda  else  (16 , 32 , 32 )
72+     assert  have_cuda 
7373    if  have_cuda :
7474        print ("Predict with GPU" )
7575        gpu_ids  =  [0 ]
7676    else :
7777        print ("Predict with CPU" )
7878        gpu_ids  =  ["cpu" ]
79+     if  halo  is  None :
80+         halo  =  (16 , 32 , 32 )
7981
8082    # Compute the global mean and standard deviation. 
81-     n_threads  =  min (16 , mp .cpu_count ())
83+     n_threads  =  min (2 , mp .cpu_count ())
8284    mean , std  =  parallel .mean_and_std (
83-         input_ , block_shape = block_shape , n_threads = n_threads , verbose = True ,
85+         input_ , block_shape = tuple ([ 2 *   i   for   i   in   input_ . chunks ]) , n_threads = n_threads , verbose = True ,
8486        mask = image_mask 
8587    )
8688    print ("Mean and standard deviation computed for the full volume:" )
@@ -98,12 +100,24 @@ def postprocess(x):
98100        x [1 ] =  vigra .filters .gaussianSmoothing (x [1 ], sigma = 2.0 )
99101        return  x 
100102
103+     shape  =  input_ .shape 
104+     ndim  =  len (shape )
105+ 
106+     blocking  =  nt .blocking ([0 ] *  ndim , shape , block_shape )
107+     n_blocks  =  blocking .numberOfBlocks 
108+     iteration_ids  =  []
109+     if  1  !=  prediction_instances :
110+         iteration_ids  =  [x .tolist () for  x  in  np .array_split (list (range (n_blocks )), prediction_instances )]
111+         slurm_iteration  =  iteration_ids [slurm_task_id ]
112+     else :
113+         slurm_iteration  =  list (range (n_blocks ))
114+ 
101115    output_path  =  os .path .join (output_folder , "predictions.zarr" )
102116    with  open_file (output_path , "a" ) as  f :
103117        output  =  f .require_dataset (
104118            "prediction" ,
105119            shape = (3 ,) +  input_ .shape ,
106-             chunks = (1 ,) +  block_shape ,
120+             chunks = (1 ,) +  chunks ,
107121            compression = "gzip" ,
108122            dtype = "float32" ,
109123        )
@@ -113,6 +127,7 @@ def postprocess(x):
113127            gpu_ids = gpu_ids , block_shape = block_shape , halo = halo ,
114128            output = output , preprocess = preprocess , postprocess = postprocess ,
115129            mask = image_mask ,
130+             iter_list = slurm_iteration ,
116131        )
117132
118133    return  original_shape 
@@ -228,14 +243,45 @@ def run_unet_prediction(
228243    output_folder , model_path ,
229244    min_size , scale = None ,
230245    block_shape = None , halo = None ,
246+     prediction_instances = 1 ,
247+ ):
248+     if  prediction_instances  >  1 :
249+         run_unet_prediction_slurm (
250+             input_path , input_key , output_folder , model_path ,
251+             scale = scale , block_shape = block_shape , halo = halo ,
252+             prediction_instances = prediction_instances ,
253+         )
254+     else :
255+         os .makedirs (output_folder , exist_ok = True )
256+ 
257+         find_mask (input_path , input_key , output_folder )
258+ 
259+         original_shape  =  prediction_impl (
260+             input_path , input_key , output_folder , model_path , scale , block_shape , halo 
261+         )
262+ 
263+         pmap_out  =  os .path .join (output_folder , "predictions.zarr" )
264+         segmentation_impl (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
265+ 
266+ def  run_unet_prediction_slurm (
267+     input_path , input_key , output_folder , model_path ,
268+     scale = None ,
269+     block_shape = None , halo = None , prediction_instances = 1 ,
231270):
232271    os .makedirs (output_folder , exist_ok = True )
272+     prediction_instances  =  int (prediction_instances )
273+     slurm_task_id  =  os .environ .get ("SLURM_ARRAY_TASK_ID" )
274+     if  slurm_task_id  is  not   None :
275+         slurm_task_id  =  int (slurm_task_id )
233276
234277    find_mask (input_path , input_key , output_folder )
235278
236279    original_shape  =  prediction_impl (
237-         input_path , input_key , output_folder , model_path , scale , block_shape , halo 
280+         input_path , input_key , output_folder , model_path , scale , block_shape , halo ,  prediction_instances ,  slurm_task_id 
238281    )
239282
283+ # does NOT need GPU, FIXME: only run on CPU 
284+ def  run_unet_segmentation_slurm (output_folder , min_size ):
285+     min_size  =  int (min_size )
240286    pmap_out  =  os .path .join (output_folder , "predictions.zarr" )
241-     segmentation_impl (pmap_out , output_folder , min_size = min_size ,  original_shape = original_shape )
287+     segmentation_impl (pmap_out , output_folder , min_size = min_size )
0 commit comments