@@ -67,6 +67,7 @@ def prediction_impl(
6767 slurm_task_id = 0 ,
6868 mean = None ,
6969 std = None ,
70+ mask = None
7071):
7172 """@private
7273 """
@@ -79,18 +80,20 @@ def prediction_impl(
7980
8081 input_ = read_image_data (input_path , input_key )
8182 chunks = getattr (input_ , "chunks" , (64 , 64 , 64 ))
82- mask_path = os .path .join (output_folder , "mask.zarr" )
83-
84- if os .path .exists (mask_path ):
85- image_mask = z5py .File (mask_path , "r" )["mask" ]
86- # resize mask
87- image_shape = input_ .shape
88- mask_shape = image_mask .shape
89- if image_shape != mask_shape :
90- image_mask = ResizedVolume (image_mask , image_shape , order = 0 )
9183
84+ if output_folder is None :
85+ image_mask = mask
9286 else :
93- image_mask = None
87+ mask_path = os .path .join (output_folder , "mask.zarr" )
88+ if os .path .exists (mask_path ):
89+ image_mask = z5py .File (mask_path , "r" )["mask" ]
90+ # resize mask
91+ image_shape = input_ .shape
92+ mask_shape = image_mask .shape
93+ if image_shape != mask_shape :
94+ image_mask = ResizedVolume (image_mask , image_shape , order = 0 )
95+ else :
96+ image_mask = mask
9497
9598 if scale is None or np .isclose (scale , 1 ):
9699 original_shape = None
@@ -162,16 +165,8 @@ def postprocess(x):
162165 else :
163166 slurm_iteration = list (range (n_blocks ))
164167
165- output_path = os .path .join (output_folder , "predictions.zarr" )
166- with open_file (output_path , "a" ) as f :
167- output = f .require_dataset (
168- "prediction" ,
169- shape = output_shape ,
170- chunks = output_chunks ,
171- compression = "gzip" ,
172- dtype = "float32" ,
173- )
174-
168+ if output_folder is None :
169+ output = np .zeros (output_shape , dtype = np .float32 )
175170 predict_with_halo (
176171 input_ , model ,
177172 gpu_ids = gpu_ids , block_shape = block_shape , halo = halo ,
@@ -180,10 +175,37 @@ def postprocess(x):
180175 iter_list = slurm_iteration ,
181176 )
182177
183- return original_shape
178+ else :
179+ output_path = os .path .join (output_folder , "predictions.zarr" )
180+ with open_file (output_path , "a" ) as f :
181+ output = f .require_dataset (
182+ "prediction" ,
183+ shape = output_shape ,
184+ chunks = output_chunks ,
185+ compression = "gzip" ,
186+ dtype = "float32" ,
187+ )
188+
189+ predict_with_halo (
190+ input_ , model ,
191+ gpu_ids = gpu_ids , block_shape = block_shape , halo = halo ,
192+ output = output , preprocess = preprocess , postprocess = postprocess ,
193+ mask = image_mask ,
194+ iter_list = slurm_iteration ,
195+ )
196+
197+ if output_folder is None :
198+ return original_shape , output
199+ else :
200+ return original_shape , None
184201
185202
186- def find_mask (input_path : str , input_key : Optional [str ], output_folder : str , seg_class : Optional [str ] = "sgn" ) -> None :
203+ def find_mask (
204+ input_path : str ,
205+ input_key : Optional [str ],
206+ output_folder : Optional [str ],
207+ seg_class : Optional [str ] = "sgn"
208+ ) -> None :
187209 """Determine the mask for running prediction.
188210
189211 The mask corresponds to data that contains actual signal and not just noise.
@@ -197,9 +219,6 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
197219 output_folder: The output folder for storing the mask data.
198220 seg_class: Specifier for exclusion criterias for mask generation.
199221 """
200- mask_path = os .path .join (output_folder , "mask.zarr" )
201- f = z5py .File (mask_path , "a" )
202-
203222 # set parameters for the exclusion of chunks within mask generation
204223 if seg_class == "sgn" :
205224 upper_percentile = 95
@@ -214,18 +233,24 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg
214233 min_intensity = 200
215234 print ("Calculating mask with default values." )
216235
217- mask_key = "mask"
218- if mask_key in f :
219- return
220-
221236 raw = read_image_data (input_path , input_key )
222237 chunks = getattr (raw , "chunks" , (64 , 64 , 64 ))
223238
224239 block_shape = tuple (2 * ch for ch in chunks )
225240 blocking = nt .blocking ([0 , 0 , 0 ], raw .shape , block_shape )
226241 n_blocks = blocking .numberOfBlocks
227242
228- ds_mask = f .create_dataset (mask_key , shape = raw .shape , compression = "gzip" , dtype = "uint8" , chunks = block_shape )
243+ if output_folder is None :
244+ ds_mask = np .zeros (raw .shape , dtype = np .uint64 )
245+
246+ else :
247+ mask_path = os .path .join (output_folder , "mask.zarr" )
248+ f = z5py .File (mask_path , "a" )
249+ mask_key = "mask"
250+ if mask_key in f :
251+ return
252+
253+ ds_mask = f .create_dataset (mask_key , shape = raw .shape , compression = "gzip" , dtype = "uint8" , chunks = block_shape )
229254
230255 # TODO more sophisticated criterion?!
231256 def find_mask_block (block_id ):
@@ -240,15 +265,20 @@ def find_mask_block(block_id):
240265 with futures .ThreadPoolExecutor (n_threads ) as tp :
241266 list (tqdm (tp .map (find_mask_block , range (n_blocks )), total = n_blocks ))
242267
268+ if output_folder is None :
269+ return ds_mask
270+ else :
271+ return None
272+
243273
244274def distance_watershed_implementation (
245275 input_path : str ,
246- output_folder : str ,
247- min_size : int ,
276+ output_folder : Optional [ str ] = None ,
277+ min_size : int = 1000 ,
248278 center_distance_threshold : float = 0.4 ,
249279 boundary_distance_threshold : Optional [float ] = None ,
250280 fg_threshold : float = 0.5 ,
251- original_shape : Optional [Tuple [int , int , int ]] = None ,
281+ original_shape : Optional [Tuple [int , int , int ]] = None
252282) -> None :
253283 """Parallel implementation of the distance-prediction based watershed.
254284
@@ -262,7 +292,10 @@ def distance_watershed_implementation(
262292 fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
263293 original_shape: The original shape to resize the segmentation to.
264294 """
265- input_ = open_file (input_path , "r" )["prediction" ]
295+ if isinstance (input_path , str ):
296+ input_ = open_file (input_path , "r" )["prediction" ]
297+ else :
298+ input_ = input_path
266299
267300 # Limit the number of cores for parallelization.
268301 n_threads = min (16 , mp .cpu_count ())
@@ -280,13 +313,17 @@ def distance_watershed_implementation(
280313 # center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
281314 # boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
282315
283- # Allocate an zarr array for the seeds.
284- block_shape = center_distances .chunks
285- seed_path = os .path .join (output_folder , "seeds.zarr" )
286- seed_file = open_file (os .path .join (seed_path ), "a" )
287- seeds = seed_file .require_dataset (
288- "seeds" , shape = center_distances .shape , chunks = block_shape , compression = "gzip" , dtype = "uint64"
289- )
316+ # Allocate the (zarr) array for the seeds.
317+ if output_folder is None :
318+ block_shape = (20 , 128 , 128 )
319+ seeds = np .zeros (center_distances .shape , dtype = np .uint64 )
320+ else :
321+ block_shape = center_distances .chunks
322+ seed_path = os .path .join (output_folder , "seeds.zarr" )
323+ seed_file = open_file (os .path .join (seed_path ), "a" )
324+ seeds = seed_file .require_dataset (
325+ "seeds" , shape = center_distances .shape , chunks = block_shape , compression = "gzip" , dtype = "uint64"
326+ )
290327
291328 # Compute the seed inputs:
292329 # First, threshold the center distances.
@@ -301,12 +338,15 @@ def distance_watershed_implementation(
301338 data = seed_inputs , out = seeds , block_shape = block_shape , mask = mask , verbose = True , n_threads = n_threads
302339 )
303340
304- # Allocate the zarr array for the segmentation.
305- seg_path = os .path .join (output_folder , "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr" )
306- seg_file = open_file (seg_path , "a" )
307- seg = seg_file .create_dataset (
308- "segmentation" , shape = seeds .shape , chunks = block_shape , compression = "gzip" , dtype = "uint64"
309- )
341+ # Allocate the (zarr) array for the segmentation.
342+ if output_folder is None :
343+ seg = np .zeros (seeds .shape , dtype = np .uint64 )
344+ else :
345+ seg_path = os .path .join (output_folder , "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr" )
346+ seg_file = open_file (seg_path , "a" )
347+ seg = seg_file .create_dataset (
348+ "segmentation" , shape = seeds .shape , chunks = block_shape , compression = "gzip" , dtype = "uint64"
349+ )
310350
311351 # Compute the segmentation with a seeded watershed
312352 halo = (2 , 8 , 8 )
@@ -341,6 +381,11 @@ def write_block(block_id):
341381 with futures .ThreadPoolExecutor (n_threads ) as tp :
342382 tp .map (write_block , range (blocking .numberOfBlocks ))
343383
384+ if output_folder is None :
385+ return seg
386+ else :
387+ return None
388+
344389
345390def calc_mean_and_std (input_path : str , input_key : str , output_folder : str ) -> None :
346391 """Calculate mean and standard deviation of the input volume.
@@ -372,7 +417,7 @@ def calc_mean_and_std(input_path: str, input_key: str, output_folder: str) -> No
372417def run_unet_prediction (
373418 input_path : str ,
374419 input_key : Optional [str ],
375- output_folder : str ,
420+ output_folder : Optional [ str ] ,
376421 model_path : str ,
377422 min_size : int ,
378423 scale : Optional [float ] = None ,
@@ -403,22 +448,33 @@ def run_unet_prediction(
403448 fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
404449 seg_class: Specifier for exclusion criterias for mask generation.
405450 """
406- os .makedirs (output_folder , exist_ok = True )
451+ if output_folder is not None :
452+ os .makedirs (output_folder , exist_ok = True )
407453
408454 if use_mask :
409- find_mask (input_path , input_key , output_folder , seg_class = seg_class )
410- original_shape = prediction_impl (
411- input_path , input_key , output_folder , model_path , scale , block_shape , halo
455+ mask = find_mask (input_path , input_key , output_folder = output_folder , seg_class = seg_class )
456+ else :
457+ mask = None
458+
459+ original_shape , prediction = prediction_impl (
460+ input_path = input_path , input_key = input_key , output_folder = output_folder , model_path = model_path , scale = scale ,
461+ block_shape = block_shape , halo = halo , mask = mask
412462 )
413463
414- pmap_out = os .path .join (output_folder , "predictions.zarr" )
415- distance_watershed_implementation (
464+ if output_folder is None :
465+ pmap_out = prediction
466+ else :
467+ pmap_out = os .path .join (output_folder , "predictions.zarr" )
468+
469+ segmentation = distance_watershed_implementation (
416470 pmap_out , output_folder , min_size = min_size , original_shape = original_shape ,
417471 center_distance_threshold = center_distance_threshold ,
418472 boundary_distance_threshold = boundary_distance_threshold ,
419- fg_threshold = fg_threshold ,
473+ fg_threshold = fg_threshold
420474 )
421475
476+ return segmentation
477+
422478
423479#
424480# ---Workflow for parallel prediction using slurm---
0 commit comments