@@ -80,7 +80,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
8080        halo  =  (16 , 32 , 32 )
8181
8282    # Compute the global mean and standard deviation. 
83-     n_threads  =  min (2 , mp .cpu_count ())
83+     n_threads  =  min (16 , mp .cpu_count ())
8484    mean , std  =  parallel .mean_and_std (
8585        input_ , block_shape = tuple ([2 *  i  for  i  in  input_ .chunks ]), n_threads = n_threads , verbose = True ,
8686        mask = image_mask 
@@ -243,25 +243,17 @@ def run_unet_prediction(
243243    output_folder , model_path ,
244244    min_size , scale = None ,
245245    block_shape = None , halo = None ,
246-     prediction_instances = 1 ,
247246):
248-     if  prediction_instances  >  1 :
249-         run_unet_prediction_slurm (
250-             input_path , input_key , output_folder , model_path ,
251-             scale = scale , block_shape = block_shape , halo = halo ,
252-             prediction_instances = prediction_instances ,
253-         )
254-     else :
255-         os .makedirs (output_folder , exist_ok = True )
247+     os .makedirs (output_folder , exist_ok = True )
256248
257-          find_mask (input_path , input_key , output_folder )
249+     find_mask (input_path , input_key , output_folder )
258250
259-          original_shape  =  prediction_impl (
260-              input_path , input_key , output_folder , model_path , scale , block_shape , halo 
261-          )
251+     original_shape  =  prediction_impl (
252+         input_path , input_key , output_folder , model_path , scale , block_shape , halo 
253+     )
262254
263-          pmap_out  =  os .path .join (output_folder , "predictions.zarr" )
264-          segmentation_impl (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
255+     pmap_out  =  os .path .join (output_folder , "predictions.zarr" )
256+     segmentation_impl (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
265257
266258def  run_unet_prediction_slurm (
267259    input_path , input_key , output_folder , model_path ,
@@ -271,8 +263,11 @@ def run_unet_prediction_slurm(
271263    os .makedirs (output_folder , exist_ok = True )
272264    prediction_instances  =  int (prediction_instances )
273265    slurm_task_id  =  os .environ .get ("SLURM_ARRAY_TASK_ID" )
266+ 
274267    if  slurm_task_id  is  not   None :
275268        slurm_task_id  =  int (slurm_task_id )
269+     else :
270+         raise  ValueError ("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH." )
276271
277272    find_mask (input_path , input_key , output_folder )
278273
0 commit comments