@@ -71,19 +71,21 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
     input_ = ResizedVolume(input_, shape=new_shape, order=3)
     image_mask = ResizedVolume(image_mask, new_shape, order=0)
 
-    chunks = (128, 128, 128)
-    block_shape = chunks
-
     have_cuda = torch.cuda.is_available()
-    assert have_cuda
+
+    if have_cuda:
+        chunks = (128, 128, 128)
+
+    if block_shape is None:
+        block_shape = chunks if have_cuda else input_.chunks
+    if halo is None:
+        halo = (16, 32, 32)
     if have_cuda:
         print("Predict with GPU")
         gpu_ids = [0]
     else:
         print("Predict with CPU")
         gpu_ids = ["cpu"]
-    if halo is None:
-        halo = (16, 32, 32)
 
     if None == mean or None == std:
         # Compute the global mean and standard deviation.
@@ -124,7 +126,7 @@ def postprocess(x):
     output = f.require_dataset(
         "prediction",
         shape=(3,) + input_.shape,
-        chunks=(1,) + chunks,
+        chunks=(1,) + block_shape,
         compression="gzip",
         dtype="float32",
     )
@@ -328,7 +330,9 @@ def run_unet_prediction_slurm(
         std = None
 
     original_shape = prediction_impl(
-        input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances, slurm_task_id, mean=mean, std=std
+        input_path, input_key, output_folder, model_path, scale, block_shape, halo,
+        prediction_instances=prediction_instances, slurm_task_id=slurm_task_id,
+        mean=mean, std=std,
     )
 
 # does NOT need GPU, FIXME: only run on CPU
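
For reference, a minimal standalone sketch of the defaulting behavior the first hunk introduces: with a GPU available a fixed 128^3 block shape is used, otherwise the input dataset's own chunking is reused, and the halo falls back to (16, 32, 32). The helper function and its arguments below are illustrative only, not part of the repository.

# Illustrative sketch only; resolve_blocking and input_chunks are hypothetical names.
import torch

def resolve_blocking(input_chunks, block_shape=None, halo=None):
    have_cuda = torch.cuda.is_available()
    if have_cuda:
        chunks = (128, 128, 128)
    # Default to a GPU-sized block, or to the dataset's native chunking on CPU.
    if block_shape is None:
        block_shape = chunks if have_cuda else input_chunks
    if halo is None:
        halo = (16, 32, 32)
    return block_shape, halo

# Example: on a CPU-only machine this returns the dataset chunking and the default halo.
print(resolve_blocking(input_chunks=(64, 64, 64)))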