@@ -84,7 +84,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
8484 print ("Predict with CPU" )
8585 gpu_ids = ["cpu" ]
8686
87- if None == mean or None == std :
87+ if mean is None or std is None :
8888 # Compute the global mean and standard deviation.
8989 n_threads = min (16 , mp .cpu_count ())
9090 mean , std = parallel .mean_and_std (
@@ -111,8 +111,7 @@ def postprocess(x):
111111
112112 blocking = nt .blocking ([0 ] * ndim , shape , block_shape )
113113 n_blocks = blocking .numberOfBlocks
114- iteration_ids = []
115- if 1 != prediction_instances :
114+ if prediction_instances != 1 :
116115 iteration_ids = [x .tolist () for x in np .array_split (list (range (n_blocks )), prediction_instances )]
117116 slurm_iteration = iteration_ids [slurm_task_id ]
118117 else :
@@ -264,7 +263,7 @@ def calc_mean_and_std(input_path, input_key, output_folder):
264263 input_ , block_shape = tuple ([2 * i for i in input_ .chunks ]), n_threads = n_threads , verbose = True ,
265264 mask = image_mask
266265 )
267- ddict = {"mean" :str ( mean ) , "std" : str ( std ) }
266+ ddict = {"mean" :mean , "std" :std }
268267 with open (json_file , "w" ) as f :
269268 json .dump (ddict , f )
270269
0 commit comments