
Commit 3638985

update save as zarr
1 parent 2e004f1 commit 3638985

File tree: 3 files changed (+232 lines, -27 lines)

tests/engines/test_nucleus_detection_engine.py

Lines changed: 47 additions & 4 deletions
@@ -23,10 +23,6 @@ def _rm_dir(path: pathlib.Path) -> None:
     shutil.rmtree(path, ignore_errors=True)
 
 
-def check_output(path: pathlib.Path) -> None:
-    """Check NucleusDetector output."""
-
-
 def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -> None:
     """Test for nucleus detection engine."""
     mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
@@ -160,6 +156,7 @@ def test_nucleus_detector_patches_dict_output(
         save_dir=None,
         class_dict=None,
     )
+    output_dict = output_dict["predictions"]
     assert len(output_dict["x"]) == 3
     assert len(output_dict["y"]) == 3
     assert len(output_dict["types"]) == 3
@@ -178,6 +175,52 @@
     assert len(output_dict["probs"][2]) == 0
 
 
+def test_nucleus_detector_patches_zarr_output(
+    remote_sample: Callable, tmp_path: pathlib.Path
+) -> None:
+    """Test for nucleus detection engine in patch mode."""
+    mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
+
+    wsi_reader = WSIReader.open(mini_wsi_svs)
+    patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp")
+    patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp")
+    patch_3 = np.zeros((252, 252, 3), dtype=np.uint8)
+
+    pretrained_model = "mapde-conic"
+
+    nucleus_detector = NucleusDetector(model=pretrained_model)
+
+    save_dir = tmp_path
+
+    output_path = nucleus_detector.run(
+        patch_mode=True,
+        device=device,
+        output_type="zarr",
+        memory_threshold=50,
+        images=[patch_1, patch_2, patch_3],
+        save_dir=save_dir,
+        class_dict=None,
+        overwrite=True,
+    )
+
+    zarr_group = zarr.open(output_path, mode="r")
+    output_dict = {
+        "x": zarr_group["x"][:],
+        "y": zarr_group["y"][:],
+        "types": zarr_group["types"][:],
+        "probs": zarr_group["probs"][:],
+        "patch_offsets": zarr_group["patch_offsets"][:],
+    }
+
+    assert len(output_dict["x"]) == 322
+    assert len(output_dict["y"]) == 322
+    assert len(output_dict["types"]) == 322
+    assert len(output_dict["probs"]) == 322
+    assert len(output_dict["patch_offsets"]) == 4
+
+    _rm_dir(save_dir)
+
+
 def test_centroid_maps_to_detection_arrays() -> None:
     """Convert centroid maps to detection arrays."""
     detection_maps = np.zeros((4, 4, 2), dtype=np.float32)
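
Note on the layout asserted above: the flat "x"/"y"/"types"/"probs" arrays hold all detections concatenated across patches, and "patch_offsets" (n_patches + 1 entries, hence 4 entries for 3 patches) records where each patch's slice starts and stops. A minimal reader-side sketch of slicing the store back into per-patch results; detections_per_patch is a hypothetical helper written for illustration, not part of the tiatoolbox API:

from pathlib import Path

import zarr


def detections_per_patch(output_path: Path) -> list[dict]:
    """Slice flat zarr detections back into one dict per patch."""
    group = zarr.open(str(output_path), mode="r")
    # offsets[i]:offsets[i + 1] bounds patch i in the flat arrays.
    offsets = group["patch_offsets"][:]
    return [
        {
            "x": group["x"][start:stop],
            "y": group["y"][start:stop],
            "types": group["types"][start:stop],
            "probs": group["probs"][start:stop],
        }
        for start, stop in zip(offsets[:-1], offsets[1:])
    ]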

tiatoolbox/models/architecture/sccnn.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from collections import OrderedDict
 from typing import TYPE_CHECKING
 
-if TYPE_CHECKING:
+if TYPE_CHECKING:  # pragma: no cover
     import numpy as np
 
     import torch
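
Context for this one-line change: a TYPE_CHECKING block is only evaluated by static type checkers, never at runtime, so coverage.py would otherwise report its imports as unexecuted lines; "# pragma: no cover" excludes the block from coverage. A standalone illustration of the pattern (example code, not from tiatoolbox):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # pragma: no cover
    # Seen by type checkers only; TYPE_CHECKING is False at runtime,
    # so without the pragma these imports count as missed lines.
    import numpy as np


def first_row(arr: np.ndarray) -> np.ndarray:
    """Postponed annotation evaluation keeps np out of the runtime path."""
    return arr[0]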

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 184 additions & 22 deletions
@@ -8,6 +8,8 @@
 import dask
 import dask.array as da
 import numpy as np
+import zarr
+from dask import compute
 from dask.diagnostics.progress import ProgressBar
 from shapely.geometry import Point
 
@@ -17,13 +19,41 @@
     SemanticSegmentor,
     SemanticSegmentorRunParams,
 )
+from tiatoolbox.wsicore.wsireader import is_zarr
 
 if TYPE_CHECKING:  # pragma: no cover
     from typing import Unpack
 
     from tiatoolbox.annotation import AnnotationStore
 
 
+def _flatten_predictions_to_dask(
+    arr: da.Array | list[np.ndarray] | np.ndarray,
+) -> da.Array:
+    """Normalise predictions to a flat 1D Dask array."""
+    # Case 1: already a Dask array
+    if isinstance(arr, da.Array):
+        # If it's already a flat numeric Dask array, just return it
+        if arr.dtype != object:
+            return arr
+        # Object-dtype Dask array: materialise then treat as list
+        arr = arr.compute()
+
+    arr_list = list(arr)
+    if arr_list is not None:
+        if len(arr_list) == 0:
+            flat_np = np.empty((0,), dtype=np.float32)
+            return da.from_array(flat_np, chunks="auto")
+
+        dask_parts = [
+            a if isinstance(a, da.Array) else da.from_array(a, chunks="auto")
+            for a in arr_list
+        ]
+        return da.concatenate(dask_parts, axis=0)
+
+    return da.from_array(arr, chunks="auto")
+
+
 class NucleusDetector(SemanticSegmentor):
     r"""Nucleus detection engine for digital pathology images.
 
@@ -138,21 +168,16 @@ def block_fn(block: np.ndarray) -> np.ndarray:
             for i in range(postproc_maps.shape[0])
         ]
 
-        def to_object_da(values: list[np.ndarray]) -> da.Array:
-            """Wrap list of numpy arrays into a single object-dtype dask array."""
-            obj_array = np.array(values, dtype=object)
-            return da.from_array(obj_array, chunks=(len(values),))
-
-        x_list = [det["x"] for det in detections]
-        y_list = [det["y"] for det in detections]
-        types_list = [det["types"] for det in detections]
-        probs_list = [det["probs"] for det in detections]
+        def to_object_da(arrs: list[da.Array]) -> da.Array:
+            """Wrap list of variable-length arrays into object-dtype dask array."""
+            obj_array = np.array(arrs, dtype=object)
+            return da.from_array(obj_array, chunks=(len(arrs),))
 
         return {
-            "x": to_object_da(x_list),
-            "y": to_object_da(y_list),
-            "types": to_object_da(types_list),
-            "probs": to_object_da(probs_list),
+            "x": to_object_da([det["x"] for det in detections]),
+            "y": to_object_da([det["y"] for det in detections]),
+            "types": to_object_da([det["types"] for det in detections]),
+            "probs": to_object_da([det["probs"] for det in detections]),
         }
 
     def post_process_wsi(
@@ -220,7 +245,7 @@ def save_predictions(
         output_type: str,
         save_path: Path | None = None,
         **kwargs: Unpack[SemanticSegmentorRunParams],
-    ) -> dict | AnnotationStore | Path:
+    ) -> dict | AnnotationStore | Path | list[Path]:
         """Save nucleus detections to disk or return them in memory.
 
         This method saves predictions in one of the supported formats:
@@ -255,13 +280,12 @@
             - returns AnnotationStore or path to .db file.
 
         """
-        if output_type.lower() != "annotationstore":
-            return super().save_predictions(
-                processed_predictions["predictions"],
-                output_type,
-                save_path=save_path,
-                **kwargs,
+        if output_type.lower() not in ["dict", "zarr", "annotationstore"]:
+            msg = (
+                f"Unsupported output_type '{output_type}'. "
+                "Supported types are 'dict', 'zarr', and 'annotationstore'."
             )
+            raise ValueError(msg)
 
         # scale_factor set from kwargs
         scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
@@ -270,9 +294,147 @@ def save_predictions(
         if class_dict is None:
             class_dict = self.model.output_class_dict
 
-        # Need to add support for zarr conversion.
-        save_paths = []
+        if output_type.lower() == "dict":
+            return super().save_predictions(
+                processed_predictions,
+                output_type,
+                save_path=save_path,
+                **kwargs,
+            )
+        if output_type.lower() == "annotationstore":
+            return self._save_predictions_annotation_store(
+                processed_predictions,
+                save_path=save_path,
+                scale_factor=scale_factor,
+                class_dict=class_dict,
+            )
+        return self._save_predictions_zarr(
+            processed_predictions,
+            save_path=save_path,
+        )
 
+    def _save_predictions_zarr(
+        self: NucleusDetector,
+        processed_predictions: dict,
+        save_path: Path | None = None,
+    ) -> Path | list[Path]:
+        """Save predictions to a Zarr store.
+
+        Args:
+            processed_predictions (dict):
+                Dictionary containing processed model predictions.
+                keys:
+                    - "predictions":
+                        {
+                        - 'x': dask array of x coordinates (np.uint32).
+                        - 'y': dask array of y coordinates (np.uint32).
+                        - 'types': dask array of detection types (np.uint32).
+                        - 'probs': dask array of detection probabilities (np.float32).
+                        }
+            save_path (Path | None):
+                Path to save the output Zarr store.
+
+        Returns:
+            Path | list[Path]:
+                Path to the saved Zarr store(s).
+
+        """
+        predictions = processed_predictions["predictions"]
+
+        keys_to_compute = [k for k in predictions if k not in self.drop_keys]
+
+        # If appending to an existing Zarr, skip keys that are already present
+        if is_zarr(save_path):
+            zarr_group = zarr.open(save_path, mode="r")
+            keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
+
+        write_tasks = []
+
+        # Compute patch_offsets from 'x' if we are in patch mode
+        patch_offsets = None
+        if self.patch_mode and "x" in predictions:
+            x_arr_list = predictions["x"].compute()
+            if x_arr_list is not None:
+                # lengths[i] = number of detections in patch i
+                lengths = np.array([len(a) for a in x_arr_list], dtype=np.int64)
+                patch_offsets = np.empty(len(lengths) + 1, dtype=np.int64)
+                patch_offsets[0] = 0
+                np.cumsum(lengths, out=patch_offsets[1:])
+
+                # Save patch_offsets as its own 1D dataset
+                offsets_da = da.from_array(patch_offsets, chunks="auto")
+                write_tasks.append(
+                    offsets_da.to_zarr(
+                        url=save_path,
+                        component="patch_offsets",
+                        compute=False,
+                    )
+                )
+
+        # Save the flattened predictions
+        for key in keys_to_compute:
+            raw = predictions[key]
+
+            # Normalise ragged per-patch predictions to a flat 1D Dask array
+            dask_array = _flatten_predictions_to_dask(raw)
+
+            # Type casting for storage
+            if key != "probs":
+                dask_array = dask_array.astype(np.uint32)
+            else:
+                dask_array = dask_array.astype(np.float32)
+
+            task = dask_array.to_zarr(
+                url=save_path,
+                component=key,
+                compute=False,
+            )
+            write_tasks.append(task)
+
+        msg = f"Saving output to {save_path}."
+        logger.info(msg=msg)
+        with ProgressBar():
+            compute(*write_tasks)
+
+        zarr_group = zarr.open(save_path, mode="r+")
+        for key in self.drop_keys:
+            if key in zarr_group:
+                del zarr_group[key]
+
+        return save_path
+
+    def _save_predictions_annotation_store(
+        self: NucleusDetector,
+        processed_predictions: dict,
+        save_path: Path | None = None,
+        scale_factor: tuple[float, float] = (1.0, 1.0),
+        class_dict: dict | None = None,
+    ) -> AnnotationStore | Path | list[Path]:
+        """Save predictions to an AnnotationStore.
+
+        Args:
+            processed_predictions (dict):
+                Dictionary containing processed model predictions.
+                keys:
+                    - "predictions":
+                        {
+                        - 'x': dask array of x coordinates (np.uint32).
+                        - 'y': dask array of y coordinates (np.uint32).
+                        - 'types': dask array of detection types (np.uint32).
+                        - 'probs': dask array of detection probabilities (np.float32).
+                        }
+            save_path (Path | None):
+                Path to save the output file.
+            scale_factor (tuple[float, float]):
+                Scaling factors for x and y coordinates.
+            class_dict (dict | None):
+                Mapping from original class IDs to new class names.
+
+        Returns:
+            AnnotationStore | Path:
+                - returns AnnotationStore or path to .db file.
+
+        """
         logger.info("Saving predictions as AnnotationStore.")
         if self.patch_mode:
            save_paths = []
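
For intuition, here is a standalone sketch of the flatten-plus-offsets technique the new _save_predictions_zarr method relies on: ragged per-patch arrays are concatenated into one flat 1D dask array, an offsets index records the patch boundaries, and both are written lazily with compute=False before a single compute pass. The array contents and the store path "demo_detections.zarr" are made up for illustration:

import dask
import dask.array as da
import numpy as np
import zarr

# Ragged per-patch detections, as produced in patch mode.
per_patch_x = [
    np.array([10, 20, 30], dtype=np.uint32),  # patch 0: 3 detections
    np.array([5], dtype=np.uint32),           # patch 1: 1 detection
    np.array([], dtype=np.uint32),            # patch 2: 0 detections
]

# offsets[i]:offsets[i + 1] bounds patch i in the flat array.
lengths = np.array([len(a) for a in per_patch_x], dtype=np.int64)
offsets = np.concatenate(([0], np.cumsum(lengths)))

# Flatten, then queue lazy zarr writes (compute=False), as in the diff.
flat = da.concatenate([da.from_array(a, chunks="auto") for a in per_patch_x])
tasks = [
    flat.to_zarr(url="demo_detections.zarr", component="x", compute=False),
    da.from_array(offsets, chunks="auto").to_zarr(
        url="demo_detections.zarr", component="patch_offsets", compute=False
    ),
]
dask.compute(*tasks)  # one pass executes all queued writes

# Read back patch 1's detections.
group = zarr.open("demo_detections.zarr", mode="r")
start, stop = group["patch_offsets"][1], group["patch_offsets"][2]
print(group["x"][start:stop])  # -> [5]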
