Commit abf02f6

reduce code complexity
1 parent 86e809d commit abf02f6

File tree

2 files changed: +69 -75 lines changed

tests/engines/test_nucleus_detection_engine.py
Lines changed: 36 additions & 23 deletions

@@ -10,7 +10,10 @@
 import zarr
 
 from tiatoolbox.annotation.storage import SQLiteStore
-from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
+from tiatoolbox.models.engine.nucleus_detector import (
+    NucleusDetector,
+    _flatten_predictions_to_dask,
+)
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils.misc import imwrite
 from tiatoolbox.wsicore.wsireader import WSIReader
@@ -43,34 +46,15 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
         save_dir=save_dir,
         overwrite=True,
         batch_size=8,
+        class_dict={0: "test_nucleus"},
     )
 
     store = SQLiteStore.open(save_dir / "wsi4_512_512.db")
     assert 255 <= len(store.values()) <= 265
+    annotation = next(iter(store.values()))
+    assert annotation.properties["type"] == "test_nucleus"
     store.close()
 
-    result_path = nucleus_detector.run(
-        patch_mode=False,
-        device=device,
-        output_type="zarr",
-        memory_threshold=50,
-        images=[mini_wsi_svs],
-        save_dir=save_dir,
-        overwrite=True,
-        batch_size=8,
-    )
-
-    zarr_path = result_path[mini_wsi_svs]
-    zarr_group = zarr.open(zarr_path, mode="r")
-    xs = zarr_group["x"][:]
-    ys = zarr_group["y"][:]
-    types = zarr_group["types"][:]
-    probs = zarr_group["probs"][:]
-    assert 255 <= len(xs) <= 265
-    assert 255 <= len(ys) <= 265
-    assert 255 <= len(types) <= 265
-    assert 255 <= len(probs) <= 265
-
     nucleus_detector.drop_keys = ["probs"]
     result_path = nucleus_detector.run(
         patch_mode=False,
@@ -330,3 +314,32 @@ def test_write_detection_records_to_store_no_class_dict() -> None:
     annotation = next(iter(dummy_store.values()))
     assert annotation.properties["type"] == 0
     dummy_store.close()
+
+
+def test_flatten_predictions_to_dask() -> None:
+    """Test flattening ragged predictions to Dask array."""
+    ragged_obj_array = np.empty(3, dtype=object)
+    ragged_obj_array[0] = np.array([1.0, 0.0], dtype=np.float32)
+    ragged_obj_array[1] = np.array([0.5, 0.5], dtype=np.float32)
+    ragged_obj_array[2] = np.array([0.2, 0.8, 0.8, 0.2], dtype=np.float32)
+
+    ragged_da_array = da.from_array(ragged_obj_array, chunks=(len(ragged_obj_array),))
+
+    flat_dask_array = _flatten_predictions_to_dask(ragged_da_array)
+    expected_array = np.array(
+        [
+            1.0,
+            0.0,
+            0.5,
+            0.5,
+            0.2,
+            0.8,
+            0.8,
+            0.2,
+        ],
+        dtype=np.float32,
+    )
+    np.testing.assert_array_equal(flat_dask_array.compute(), expected_array)
+
+    flat_dask_array = _flatten_predictions_to_dask(ragged_obj_array)
+    np.testing.assert_array_equal(flat_dask_array.compute(), expected_array)
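
The contract this new test pins down can be reproduced standalone. The flatten_ragged helper below is a hypothetical sketch of the behaviour _flatten_predictions_to_dask exhibits in the hunks of this commit (numeric Dask arrays pass straight through; object-dtype inputs are materialised and concatenated); it is not the toolbox's implementation.

import dask.array as da
import numpy as np


def flatten_ragged(arr) -> da.Array:
    # Hypothetical stand-in for _flatten_predictions_to_dask (illustration only).
    if isinstance(arr, da.Array):
        if arr.dtype != object:
            return arr  # already a flat numeric Dask array
        arr = arr.compute()  # object dtype: materialise, then treat as a list
    parts = [np.asarray(a, dtype=np.float32) for a in arr]
    flat = np.concatenate(parts) if parts else np.empty(0, dtype=np.float32)
    return da.from_array(flat, chunks="auto")


ragged = np.empty(2, dtype=object)
ragged[0] = np.array([1.0, 0.0], dtype=np.float32)
ragged[1] = np.array([0.5], dtype=np.float32)
print(flatten_ragged(ragged).compute())  # [1.  0.  0.5]
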

tiatoolbox/models/engine/nucleus_detector.py
Lines changed: 33 additions & 52 deletions

@@ -38,10 +38,9 @@ def _flatten_predictions_to_dask(
     """Normalise predictions to a flat 1D Dask array."""
     # Case 1: already a Dask array
     if isinstance(arr, da.Array):
-        # If it's already a flat numeric Dask array, just return it
+        # If it's already a numeric Dask array, just return it
         if arr.dtype != object:
            return arr
-        # Object-dtype Dask array: materialise then treat as list
         arr = arr.compute()
 
     arr_list = list(arr)
@@ -134,37 +133,20 @@ def post_process_patches(
             - "probs": dask array of detection probabilities (np.float32).
 
         """
+        logger.info("Post processing patch predictions in NucleusDetector")
         _ = kwargs.get("return_probabilities")
         _ = prediction_shape
         _ = prediction_dtype
 
-        # Ensure chunks are full in spatial/channel dims; batch dim can vary
-        raw_predictions = raw_predictions.rechunk({0: 1})
-
-        def block_fn(block: np.ndarray) -> np.ndarray:
-            """Apply model's post-processing function to each block.
-
-            Args:
-                block: (b_chunk, H, W, C) NumPy array representing a chunk of
-                    raw patch predictions.
-            returns:
-                Processed NumPy array after applying the model's post-processing.
-            """
-            return np.stack(
-                [self.model.postproc_func(sample) for sample in block], axis=0
+        detection_arrays = []
+        for i in range(raw_predictions.shape[0]):
+            patch_pred = raw_predictions[i]
+            postproc_map = da.from_array(
+                self.model.postproc(patch_pred), chunks=patch_pred.chunks
+            )
+            detection_arrays.append(
+                self._centroid_maps_to_detection_arrays(postproc_map)
             )
-
-        postproc_maps = da.map_blocks(
-            block_fn,
-            raw_predictions,
-            dtype=raw_predictions.dtype,
-        )
-
-        # Convert each patch's centroid map to detection records and aggregate
-        detections = [
-            self._centroid_maps_to_detection_arrays(postproc_maps[i])
-            for i in range(postproc_maps.shape[0])
-        ]
 
         def to_object_da(arrs: list[da.Array]) -> da.Array:
             """Wrap list of variable-length arrays into object-dtype dask array."""
@@ -177,10 +159,10 @@ def to_object_da(arrs: list[da.Array]) -> da.Array:
             return da.from_array(obj_array, chunks=(len(arrs),))
 
         return {
-            "x": to_object_da([det["x"] for det in detections]),
-            "y": to_object_da([det["y"] for det in detections]),
-            "types": to_object_da([det["types"] for det in detections]),
-            "probs": to_object_da([det["probs"] for det in detections]),
+            "x": to_object_da([det["x"] for det in detection_arrays]),
+            "y": to_object_da([det["y"] for det in detection_arrays]),
+            "types": to_object_da([det["types"] for det in detection_arrays]),
+            "probs": to_object_da([det["probs"] for det in detection_arrays]),
         }
 
     def post_process_wsi(
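
The to_object_da closure shown above wraps ragged per-patch results without forcing them into one rectangular array. A minimal sketch of the pattern, assuming the elided body fills an object-dtype buffer element by element (only the da.from_array(obj_array, chunks=(len(arrs),)) line is visible in the diff):

import dask.array as da
import numpy as np


def to_object_da(arrs: list) -> da.Array:
    # Assigning into np.empty(..., dtype=object) avoids NumPy trying to
    # broadcast the variable-length arrays into a single 2D array.
    obj_array = np.empty(len(arrs), dtype=object)
    for i, a in enumerate(arrs):
        obj_array[i] = a
    # A single chunk keeps the ragged list together for later flattening.
    return da.from_array(obj_array, chunks=(len(arrs),))


per_patch_x = [np.array([10.0, 42.0]), np.array([7.0])]  # ragged centroid lists
print(to_object_da(per_patch_x).compute()[0])  # [10. 42.]
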
@@ -212,10 +194,9 @@ def post_process_wsi(
             - "probs": dask array of detection probabilities.
 
         """
+        _ = prediction_shape
+
         logger.info("Post processing WSI predictions in NucleusDetector")
-        logger.info("Raw probabilities shape: %s", prediction_shape)
-        logger.info("Raw probabilities dtype %s", prediction_dtype)
-        logger.info("Raw chunk size: %s", raw_predictions.chunks)
 
         # Add halo (overlap) around each block for post-processing
         depth_h = self.model.min_distance
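
The halo mentioned in the context lines is standard Dask overlap handling: each block is extended by min_distance pixels from its neighbours so detections near chunk borders see enough context. A generic sketch of the mechanism (not this engine's code; the thresholding lambda is a placeholder for real peak detection):

import dask.array as da
import numpy as np

prob_map = da.random.random((1024, 1024), chunks=(256, 256))
depth = 11  # stand-in for self.model.min_distance

peaks = da.map_overlap(
    lambda block: (block > 0.5).astype(np.uint8),  # placeholder post-processing
    prob_map,
    depth=depth,         # each block sees `depth` extra pixels from neighbours
    boundary="reflect",  # pad at the outer edge of the full array
    dtype=np.uint8,
)
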
@@ -350,36 +331,36 @@ def _save_predictions_zarr(
         patch_offsets = None
         if self.patch_mode and "x" in predictions:
             x_arr_list = predictions["x"].compute()
-            if x_arr_list is not None:
-                # lengths[i] = number of detections in patch i
-                lengths = np.array([len(a) for a in x_arr_list], dtype=np.int64)
-                patch_offsets = np.empty(len(lengths) + 1, dtype=np.int64)
-                patch_offsets[0] = 0
-                np.cumsum(lengths, out=patch_offsets[1:])
-
-                # Save patch_offsets as its own 1D dataset
-                offsets_da = da.from_array(patch_offsets, chunks="auto")
-                write_tasks.append(
-                    offsets_da.to_zarr(
-                        url=save_path,
-                        component="patch_offsets",
-                        compute=False,
-                    )
+
+            # lengths[i] = number of detections in patch i
+            lengths = np.array([len(a) for a in x_arr_list], dtype=np.int64)
+            patch_offsets = np.empty(len(lengths) + 1, dtype=np.int64)
+            patch_offsets[0] = 0
+            np.cumsum(lengths, out=patch_offsets[1:])
+
+            # Save patch_offsets as its own 1D dataset
+            offsets_da = da.from_array(patch_offsets, chunks="auto")
+            write_tasks.append(
+                offsets_da.to_zarr(
+                    url=save_path,
+                    component="patch_offsets",
+                    compute=False,
                 )
+            )
 
         # ---------------- save flattened predictions -----------------
         for key in keys_to_compute:
             raw = predictions[key]
 
-            # Normalise ragged per-patch predictions to a flat 1D Dask array
             dask_array = _flatten_predictions_to_dask(raw)
-
             # Type casting for storage
             if key != "probs":
                 dask_array = dask_array.astype(np.uint32)
             else:
                 dask_array = dask_array.astype(np.float32)
 
+            # Normalise ragged per-patch predictions to a flat 1D Dask array
+
             task = dask_array.to_zarr(
                 url=save_path,
                 component=key,
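
The patch_offsets dataset written above is a prefix-sum index over the flat detection arrays: offsets[i]:offsets[i + 1] delimits patch i's detections. A hypothetical reader for this layout (the zarr path is illustrative; "probs" may be absent when drop_keys=["probs"] was used in the run):

import zarr

group = zarr.open("output/wsi4_512_512.zarr", mode="r")  # illustrative path
offsets = group["patch_offsets"][:]  # length n_patches + 1, offsets[0] == 0


def detections_for_patch(i: int) -> dict:
    # Slice the flat per-detection arrays back into per-patch records.
    start, stop = offsets[i], offsets[i + 1]
    return {key: group[key][start:stop] for key in ("x", "y", "types")}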
