@@ -38,10 +38,9 @@ def _flatten_predictions_to_dask(
3838 """Normalise predictions to a flat 1D Dask array."""
3939 # # Case 1: already a Dask array
4040 if isinstance (arr , da .Array ):
41- # If it's already a flat numeric Dask array, just return it
41+ # If it's already a numeric Dask array, just return it
4242 if arr .dtype != object :
4343 return arr
44- # Object-dtype Dask array: materialise then treat as list
4544 arr = arr .compute ()
4645
4746 arr_list = list (arr )
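
For context, here is a minimal standalone sketch of the normalisation this helper performs on the object-dtype path. The function name and behaviour below are assumptions based on the docstring, not the toolbox API:

import dask.array as da
import numpy as np

def flatten_ragged(arr_list: list) -> da.Array:
    # Hypothetical stand-in for _flatten_predictions_to_dask: concatenate
    # variable-length per-patch 1D arrays into one flat 1D Dask array.
    if not arr_list:
        return da.from_array(np.empty((0,), dtype=np.float32), chunks="auto")
    flat = np.concatenate([np.asarray(a).ravel() for a in arr_list])
    return da.from_array(flat, chunks="auto")

# Patches with 3, 0 and 2 detections flatten to a single length-5 array.
ragged = [np.array([1.0, 2.0, 3.0]), np.array([]), np.array([4.0, 5.0])]
print(flatten_ragged(ragged).compute())  # [1. 2. 3. 4. 5.]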
@@ -134,37 +133,20 @@ def post_process_patches(
         - "probs": dask array of detection probabilities (np.float32).

         """
+        logger.info("Post processing patch predictions in NucleusDetector")
         _ = kwargs.get("return_probabilities")
         _ = prediction_shape
         _ = prediction_dtype

-        # Ensure chunks are full in spatial/channel dims; batch dim can vary
-        raw_predictions = raw_predictions.rechunk({0: 1})
-
-        def block_fn(block: np.ndarray) -> np.ndarray:
-            """Apply the model's post-processing function to each block.
-
-            Args:
-                block: (b_chunk, H, W, C) NumPy array representing a chunk of
-                    raw patch predictions.
-            Returns:
-                Processed NumPy array after applying the model's post-processing.
-            """
-            return np.stack(
-                [self.model.postproc_func(sample) for sample in block], axis=0
+        detection_arrays = []
+        for i in range(raw_predictions.shape[0]):
+            patch_pred = raw_predictions[i]
+            postproc_map = da.from_array(
+                self.model.postproc(patch_pred), chunks=patch_pred.chunks
+            )
+            detection_arrays.append(
+                self._centroid_maps_to_detection_arrays(postproc_map)
             )
-
-        postproc_maps = da.map_blocks(
-            block_fn,
-            raw_predictions,
-            dtype=raw_predictions.dtype,
-        )
-
-        # Convert each patch's centroid map to detection records and aggregate
-        detections = [
-            self._centroid_maps_to_detection_arrays(postproc_maps[i])
-            for i in range(postproc_maps.shape[0])
-        ]

         def to_object_da(arrs: list[da.Array]) -> da.Array:
             """Wrap list of variable-length arrays into object-dtype dask array."""
@@ -177,10 +159,10 @@ def to_object_da(arrs: list[da.Array]) -> da.Array:
             return da.from_array(obj_array, chunks=(len(arrs),))

         return {
-            "x": to_object_da([det["x"] for det in detections]),
-            "y": to_object_da([det["y"] for det in detections]),
-            "types": to_object_da([det["types"] for det in detections]),
-            "probs": to_object_da([det["probs"] for det in detections]),
+            "x": to_object_da([det["x"] for det in detection_arrays]),
+            "y": to_object_da([det["y"] for det in detection_arrays]),
+            "types": to_object_da([det["types"] for det in detection_arrays]),
+            "probs": to_object_da([det["probs"] for det in detection_arrays]),
         }

     def post_process_wsi(
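
A hypothetical consumer of the dict returned above, with names and the argument list simplified for illustration; each value is an object-dtype Dask array holding one variable-length array per patch:

# Assumed: `detector` is a NucleusDetector instance and `raw_predictions`
# its stacked patch output; the real method takes further arguments.
detections = detector.post_process_patches(raw_predictions)
xs = detections["x"].compute()  # object array: one coordinate array per patch
ys = detections["y"].compute()
for patch_idx, (px, py) in enumerate(zip(xs, ys)):
    print(f"patch {patch_idx}: {len(px)} nuclei")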
@@ -212,10 +194,9 @@ def post_process_wsi(
         - "probs": dask array of detection probabilities.

         """
+        _ = prediction_shape
+
         logger.info("Post processing WSI predictions in NucleusDetector")
-        logger.info("Raw probabilities shape: %s", prediction_shape)
-        logger.info("Raw probabilities dtype %s", prediction_dtype)
-        logger.info("Raw chunk size: %s", raw_predictions.chunks)

         # Add halo (overlap) around each block for post-processing
         depth_h = self.model.min_distance
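
The halo idea: each Dask block is expanded by min_distance pixels borrowed from its neighbours before post-processing, so maxima near chunk borders are judged with full context. A generic dask sketch of the mechanism (not the detector's actual call; all names here are made up):

import dask.array as da
import numpy as np

min_distance = 8  # stands in for self.model.min_distance
prob_map = da.random.random((1024, 1024), chunks=(256, 256))

def threshold_block(block: np.ndarray) -> np.ndarray:
    # Toy per-block step; a real detector would run peak detection here.
    return (block > 0.999).astype(np.uint8)

# depth adds a min_distance-pixel halo shared between neighbouring blocks;
# dask trims it again after the function runs, so the output shape is kept.
peaks = da.map_overlap(
    threshold_block, prob_map, depth=min_distance, boundary="none", dtype=np.uint8
)
print(peaks.shape, int(peaks.sum().compute()))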
@@ -350,36 +331,36 @@ def _save_predictions_zarr(
         patch_offsets = None
         if self.patch_mode and "x" in predictions:
             x_arr_list = predictions["x"].compute()
-            if x_arr_list is not None:
-                # lengths[i] = number of detections in patch i
-                lengths = np.array([len(a) for a in x_arr_list], dtype=np.int64)
-                patch_offsets = np.empty(len(lengths) + 1, dtype=np.int64)
-                patch_offsets[0] = 0
-                np.cumsum(lengths, out=patch_offsets[1:])
-
-                # Save patch_offsets as its own 1D dataset
-                offsets_da = da.from_array(patch_offsets, chunks="auto")
-                write_tasks.append(
-                    offsets_da.to_zarr(
-                        url=save_path,
-                        component="patch_offsets",
-                        compute=False,
-                    )
+
+            # lengths[i] = number of detections in patch i
+            lengths = np.array([len(a) for a in x_arr_list], dtype=np.int64)
+            patch_offsets = np.empty(len(lengths) + 1, dtype=np.int64)
+            patch_offsets[0] = 0
+            np.cumsum(lengths, out=patch_offsets[1:])
+
+            # Save patch_offsets as its own 1D dataset
+            offsets_da = da.from_array(patch_offsets, chunks="auto")
+            write_tasks.append(
+                offsets_da.to_zarr(
+                    url=save_path,
+                    component="patch_offsets",
+                    compute=False,
                 )
+            )

         # ---------------- save flattened predictions -----------------
         for key in keys_to_compute:
             raw = predictions[key]

-            # Normalise ragged per-patch predictions to a flat 1D Dask array
             dask_array = _flatten_predictions_to_dask(raw)
-
             # Type casting for storage
             if key != "probs":
                 dask_array = dask_array.astype(np.uint32)
             else:
                 dask_array = dask_array.astype(np.float32)

+            # Normalise ragged per-patch predictions to a flat 1D Dask array
+
             task = dask_array.to_zarr(
                 url=save_path,
                 component=key,
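
The patch_offsets dataset written above is a CSR-style index into the flattened arrays: the detections of patch i occupy flat[patch_offsets[i]:patch_offsets[i + 1]]. A small reader-side sketch with made-up values:

import numpy as np

lengths = np.array([3, 0, 2], dtype=np.int64)       # detections per patch
patch_offsets = np.empty(len(lengths) + 1, dtype=np.int64)
patch_offsets[0] = 0
np.cumsum(lengths, out=patch_offsets[1:])
print(patch_offsets)                                # [0 3 3 5]

flat_x = np.array([10, 11, 12, 40, 41], dtype=np.uint32)  # flattened "x"
i = 2
print(flat_x[patch_offsets[i]:patch_offsets[i + 1]])      # [40 41]

Because every to_zarr call here passes compute=False, the writes come back as lazy tasks; presumably they are executed together later (e.g. via dask.compute(*write_tasks)), so all datasets are written in one pass.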