88import dask
99import dask .array as da
1010import numpy as np
11+ import zarr
12+ from dask import compute
1113from dask .diagnostics .progress import ProgressBar
1214from shapely .geometry import Point
1315
1719 SemanticSegmentor ,
1820 SemanticSegmentorRunParams ,
1921)
22+ from tiatoolbox .wsicore .wsireader import is_zarr
2023
2124if TYPE_CHECKING : # pragma: no cover
2225 from typing import Unpack
2326
2427 from tiatoolbox .annotation import AnnotationStore
2528
2629
def _flatten_predictions_to_dask(
    arr: da.Array | list[np.ndarray] | np.ndarray,
) -> da.Array:
    """Normalise per-patch predictions to a flat 1D Dask array.

    Args:
        arr (da.Array | list[np.ndarray] | np.ndarray):
            Either a flat numeric Dask array (returned unchanged), an
            object-dtype Dask array of ragged per-patch arrays, a list of
            NumPy arrays, or a NumPy array.

    Returns:
        da.Array:
            A single 1D Dask array containing all predictions
            concatenated in patch order.

    """
    # Case 1: already a Dask array.
    if isinstance(arr, da.Array):
        if arr.dtype != object:
            # Flat numeric Dask array: nothing to normalise.
            return arr
        # Object-dtype Dask array holds ragged per-patch arrays;
        # materialise it so the elements can be concatenated below.
        arr = arr.compute()

    # Case 2: a (possibly empty) sequence of per-patch arrays.
    # NOTE: list(...) never yields None, so no None-guard is needed here
    # (the original always-true check and its dead fallback were removed).
    arr_list = list(arr)
    if not arr_list:
        # No detections at all: return an empty float32 array so the
        # caller can still cast and store it.
        return da.from_array(np.empty((0,), dtype=np.float32), chunks="auto")

    # Wrap any plain NumPy parts so everything concatenates lazily.
    dask_parts = [
        part if isinstance(part, da.Array) else da.from_array(part, chunks="auto")
        for part in arr_list
    ]
    return da.concatenate(dask_parts, axis=0)
56+
2757class NucleusDetector (SemanticSegmentor ):
2858 r"""Nucleus detection engine for digital pathology images.
2959
@@ -138,21 +168,16 @@ def block_fn(block: np.ndarray) -> np.ndarray:
138168 for i in range (postproc_maps .shape [0 ])
139169 ]
140170
141- def to_object_da (values : list [np .ndarray ]) -> da .Array :
142- """Wrap list of numpy arrays into a single object-dtype dask array."""
143- obj_array = np .array (values , dtype = object )
144- return da .from_array (obj_array , chunks = (len (values ),))
145-
146- x_list = [det ["x" ] for det in detections ]
147- y_list = [det ["y" ] for det in detections ]
148- types_list = [det ["types" ] for det in detections ]
149- probs_list = [det ["probs" ] for det in detections ]
171+ def to_object_da (arrs : list [da .Array ]) -> da .Array :
172+ """Wrap list of variable-length arrays into object-dtype dask array."""
173+ obj_array = np .array (arrs , dtype = object )
174+ return da .from_array (obj_array , chunks = (len (arrs ),))
150175
151176 return {
152- "x" : to_object_da (x_list ),
153- "y" : to_object_da (y_list ),
154- "types" : to_object_da (types_list ),
155- "probs" : to_object_da (probs_list ),
177+ "x" : to_object_da ([ det [ "x" ] for det in detections ] ),
178+ "y" : to_object_da ([ det [ "y" ] for det in detections ] ),
179+ "types" : to_object_da ([ det [ "types" ] for det in detections ] ),
180+ "probs" : to_object_da ([ det [ "probs" ] for det in detections ] ),
156181 }
157182
158183 def post_process_wsi (
@@ -220,7 +245,7 @@ def save_predictions(
220245 output_type : str ,
221246 save_path : Path | None = None ,
222247 ** kwargs : Unpack [SemanticSegmentorRunParams ],
223- ) -> dict | AnnotationStore | Path :
248+ ) -> dict | AnnotationStore | Path | list [ Path ] :
224249 """Save nucleus detections to disk or return them in memory.
225250
226251 This method saves predictions in one of the supported formats:
@@ -255,13 +280,12 @@ def save_predictions(
255280 - returns AnnotationStore or path to .db file.
256281
257282 """
258- if output_type .lower () != "annotationstore" :
259- return super ().save_predictions (
260- processed_predictions ["predictions" ],
261- output_type ,
262- save_path = save_path ,
263- ** kwargs ,
283+ if output_type .lower () not in ["dict" , "zarr" , "annotationstore" ]:
284+ msg = (
285+ f"Unsupported output_type '{ output_type } '. "
286+ "Supported types are 'dict', 'zarr', and 'annotationstore'."
264287 )
288+ raise ValueError (msg )
265289
266290 # scale_factor set from kwargs
267291 scale_factor = kwargs .get ("scale_factor" , (1.0 , 1.0 ))
@@ -270,9 +294,147 @@ def save_predictions(
270294 if class_dict is None :
271295 class_dict = self .model .output_class_dict
272296
273- # Need to add support for zarr conversion.
274- save_paths = []
297+ if output_type .lower () == "dict" :
298+ return super ().save_predictions (
299+ processed_predictions ,
300+ output_type ,
301+ save_path = save_path ,
302+ ** kwargs ,
303+ )
304+ if output_type .lower () == "annotationstore" :
305+ return self ._save_predictions_annotation_store (
306+ processed_predictions ,
307+ save_path = save_path ,
308+ scale_factor = scale_factor ,
309+ class_dict = class_dict ,
310+ )
311+ return self ._save_predictions_zarr (
312+ processed_predictions ,
313+ save_path = save_path ,
314+ )
275315
316+ def _save_predictions_zarr (
317+ self : NucleusDetector ,
318+ processed_predictions : dict ,
319+ save_path : Path | None = None ,
320+ ) -> Path | list [Path ]:
321+ """Save predictions to a Zarr store.
322+
323+ Args:
324+ processed_predictions (dict):
325+ Dictionary containing processed model predictions.
326+ keys:
327+ - "predictions":
328+ {
329+ - 'x': dask array of x coordinates (np.uint32).
330+ - 'y': dask array of y coordinates (np.uint32).
331+ - 'types': dask array of detection types (np.uint32).
332+ - 'probs': dask array of detection probabilities (np.float32).
333+ }
334+ save_path (Path | None):
335+ Path to save the output Zarr store.
336+
337+ Returns:
338+ Path | list[Path]:
339+ Path to the saved Zarr store(s).
340+
341+ """
342+ predictions = processed_predictions ["predictions" ]
343+
344+ keys_to_compute = [k for k in predictions if k not in self .drop_keys ]
345+
346+ # If appending to an existing Zarr, skip keys that are already present
347+ if is_zarr (save_path ):
348+ zarr_group = zarr .open (save_path , mode = "r" )
349+ keys_to_compute = [k for k in keys_to_compute if k not in zarr_group ]
350+
351+ write_tasks = []
352+
353+ # --- NEW: compute patch_offsets from 'x' if we are in patch mode ----
354+ patch_offsets = None
355+ if self .patch_mode and "x" in predictions :
356+ x_arr_list = predictions ["x" ].compute ()
357+ if x_arr_list is not None :
358+ # lengths[i] = number of detections in patch i
359+ lengths = np .array ([len (a ) for a in x_arr_list ], dtype = np .int64 )
360+ patch_offsets = np .empty (len (lengths ) + 1 , dtype = np .int64 )
361+ patch_offsets [0 ] = 0
362+ np .cumsum (lengths , out = patch_offsets [1 :])
363+
364+ # Save patch_offsets as its own 1D dataset
365+ offsets_da = da .from_array (patch_offsets , chunks = "auto" )
366+ write_tasks .append (
367+ offsets_da .to_zarr (
368+ url = save_path ,
369+ component = "patch_offsets" ,
370+ compute = False ,
371+ )
372+ )
373+
374+ # ---------------- save flattened predictions -----------------
375+ for key in keys_to_compute :
376+ raw = predictions [key ]
377+
378+ # Normalise ragged per-patch predictions to a flat 1D Dask array
379+ dask_array = _flatten_predictions_to_dask (raw )
380+
381+ # Type casting for storage
382+ if key != "probs" :
383+ dask_array = dask_array .astype (np .uint32 )
384+ else :
385+ dask_array = dask_array .astype (np .float32 )
386+
387+ task = dask_array .to_zarr (
388+ url = save_path ,
389+ component = key ,
390+ compute = False ,
391+ )
392+ write_tasks .append (task )
393+
394+ msg = f"Saving output to { save_path } ."
395+ logger .info (msg = msg )
396+ with ProgressBar ():
397+ compute (* write_tasks )
398+
399+ zarr_group = zarr .open (save_path , mode = "r+" )
400+ for key in self .drop_keys :
401+ if key in zarr_group :
402+ del zarr_group [key ]
403+
404+ return save_path
405+
406+ def _save_predictions_annotation_store (
407+ self : NucleusDetector ,
408+ processed_predictions : dict ,
409+ save_path : Path | None = None ,
410+ scale_factor : tuple [float , float ] = (1.0 , 1.0 ),
411+ class_dict : dict | None = None ,
412+ ) -> AnnotationStore | Path | list [Path ]:
413+ """Save predictions to an AnnotationStore.
414+
415+ Args:
416+ processed_predictions (dict):
417+ Dictionary containing processed model predictions.
418+ keys:
419+ - "predictions":
420+ {
421+ - 'x': dask array of x coordinates (np.uint32).
422+ - 'y': dask array of y coordinates (np.uint32).
423+ - 'types': dask array of detection types (np.uint32).
424+ - 'probs': dask array of detection probabilities (np.float32).
425+ }
426+ save_path (Path | None):
427+ Path to save the output file.
428+ scale_factor (tuple[float, float]):
429+ Scaling factors for x and y coordinates.
430+ class_dict (dict | None):
431+ Mapping from original class IDs to new class names.
432+
433+ Returns:
434+ AnnotationStore | Path:
435+ - returns AnnotationStore or path to .db file.
436+
437+ """
276438 logger .info ("Saving predictions as AnnotationStore." )
277439 if self .patch_mode :
278440 save_paths = []
0 commit comments