Skip to content

Commit 2e004f1

Browse files
committed
refactor postprocessing and saving
1 parent a90f748 commit 2e004f1

File tree

2 files changed

+263
-246
lines changed

2 files changed

+263
-246
lines changed

tests/engines/test_nucleus_detection_engine.py

Lines changed: 81 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
import dask.array as da
88
import numpy as np
9-
import pandas as pd
10-
import pytest
9+
import zarr
1110

1211
from tiatoolbox.annotation.storage import SQLiteStore
1312
from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
@@ -28,91 +27,6 @@ def check_output(path: pathlib.Path) -> None:
2827
"""Check NucleusDetector output."""
2928

3029

31-
def test_nucleus_detection_nms_empty_dataframe() -> None:
32-
"""nucleus_detection_nms should return a copy for empty inputs."""
33-
df = pd.DataFrame(columns=["x", "y", "type", "prob"])
34-
35-
result = NucleusDetector.nucleus_detection_nms(df, radius=3)
36-
37-
assert result.empty
38-
assert result is not df
39-
assert list(result.columns) == ["x", "y", "type", "prob"]
40-
41-
42-
def test_nucleus_detection_nms_invalid_radius() -> None:
43-
"""Radius must be strictly positive."""
44-
df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})
45-
46-
with pytest.raises(ValueError, match="radius must be > 0"):
47-
NucleusDetector.nucleus_detection_nms(df, radius=0)
48-
49-
50-
def test_nucleus_detection_nms_invalid_overlap_threshold() -> None:
51-
"""overlap_threshold must lie in (0, 1]."""
52-
df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})
53-
54-
message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0"
55-
with pytest.raises(ValueError, match=message):
56-
NucleusDetector.nucleus_detection_nms(df, radius=1, overlap_threshold=0)
57-
58-
59-
def test_nucleus_detection_nms_suppresses_overlapping_detections() -> None:
60-
"""Lower-probability overlapping detections are removed."""
61-
df = pd.DataFrame(
62-
{
63-
"x": [2, 0, 20],
64-
"y": [1, 0, 20],
65-
"type": [1, 1, 2],
66-
"prob": [0.6, 0.9, 0.7],
67-
}
68-
)
69-
70-
result = NucleusDetector.nucleus_detection_nms(df, radius=5)
71-
72-
expected = pd.DataFrame(
73-
{"x": [0, 20], "y": [0, 20], "type": [1, 2], "prob": [0.9, 0.7]}
74-
)
75-
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)
76-
77-
78-
def test_nucleus_detection_nms_suppresses_across_types() -> None:
79-
"""Overlapping detections of different types are also suppressed."""
80-
df = pd.DataFrame(
81-
{
82-
"x": [0, 0, 20],
83-
"y": [0, 0, 0],
84-
"type": [1, 2, 1],
85-
"prob": [0.6, 0.95, 0.4],
86-
}
87-
)
88-
89-
result = NucleusDetector.nucleus_detection_nms(df, radius=5)
90-
91-
expected = pd.DataFrame(
92-
{"x": [0, 20], "y": [0, 0], "type": [2, 1], "prob": [0.95, 0.4]}
93-
)
94-
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)
95-
96-
97-
def test_nucleus_detection_nms_retains_non_overlapping_candidates() -> None:
98-
"""Detections with IoU below the threshold are preserved."""
99-
df = pd.DataFrame(
100-
{
101-
"x": [0, 10],
102-
"y": [0, 0],
103-
"type": [1, 1],
104-
"prob": [0.8, 0.5],
105-
}
106-
)
107-
108-
result = NucleusDetector.nucleus_detection_nms(df, radius=5, overlap_threshold=0.5)
109-
110-
expected = pd.DataFrame(
111-
{"x": [0, 10], "y": [0, 0], "type": [1, 1], "prob": [0.8, 0.5]}
112-
)
113-
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)
114-
115-
11630
def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -> None:
11731
"""Test for nucleus detection engine."""
11832
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
@@ -136,10 +50,31 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
13650
assert len(store.values()) == 281
13751
store.close()
13852

53+
result_path = nucleus_detector.run(
54+
patch_mode=False,
55+
device=device,
56+
output_type="zarr",
57+
memory_threshold=50,
58+
images=[mini_wsi_svs],
59+
save_dir=save_dir,
60+
overwrite=True,
61+
)
62+
63+
zarr_path = result_path[mini_wsi_svs]
64+
zarr_group = zarr.open(zarr_path, mode="r")
65+
xs = zarr_group["x"][:]
66+
ys = zarr_group["y"][:]
67+
types = zarr_group["types"][:]
68+
probs = zarr_group["probs"][:]
69+
assert len(xs) == 281
70+
assert len(ys) == 281
71+
assert len(types) == 281
72+
assert len(probs) == 281
73+
13974
_rm_dir(save_dir)
14075

14176

142-
def test_nucleus_detector_patch(
77+
def test_nucleus_detector_patch_annotation_store_output(
14378
remote_sample: Callable, tmp_path: pathlib.Path
14479
) -> None:
14580
"""Test for nucleus detection engine in patch mode."""
@@ -183,7 +118,7 @@ def test_nucleus_detector_patch(
183118
_ = nucleus_detector.run(
184119
patch_mode=True,
185120
device=device,
186-
output_type="zarr",
121+
output_type="annotationstore",
187122
memory_threshold=50,
188123
images=[save_dir / "patch_0.png", save_dir / "patch_1.png"],
189124
save_dir=save_dir,
@@ -201,30 +136,63 @@ def test_nucleus_detector_patch(
201136
_rm_dir(save_dir)
202137

203138

204-
def test_nucleus_detector_write_centroid_maps(tmp_path: pathlib.Path) -> None:
205-
"""Test for _write_centroid_maps function."""
206-
detection_maps = np.zeros((20, 20, 1), dtype=np.uint8)
207-
detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1))
139+
def test_nucleus_detector_patches_dict_output(
140+
remote_sample: Callable,
141+
) -> None:
142+
"""Test for nucleus detection engine in patch mode."""
143+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
208144

209-
store = NucleusDetector.write_centroid_maps_to_store(
210-
detection_maps=detection_maps, class_dict=None
211-
)
212-
assert len(store.values()) == 0
213-
store.close()
145+
wsi_reader = WSIReader.open(mini_wsi_svs)
146+
patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp")
147+
patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp")
148+
patch_3 = np.zeros((252, 252, 3), dtype=np.uint8)
149+
150+
pretrained_model = "mapde-conic"
151+
152+
nucleus_detector = NucleusDetector(model=pretrained_model)
214153

215-
detection_maps = np.zeros((20, 20, 1), dtype=np.uint8)
216-
detection_maps[10, 10, 0] = 1
217-
detection_maps = da.from_array(detection_maps, chunks=(20, 20, 1))
218-
_ = NucleusDetector.write_centroid_maps_to_store(
219-
detection_maps=detection_maps,
220-
save_path=tmp_path / "test.db",
221-
class_dict={0: "nucleus"},
154+
output_dict = nucleus_detector.run(
155+
patch_mode=True,
156+
device=device,
157+
output_type="dict",
158+
memory_threshold=50,
159+
images=[patch_1, patch_2, patch_3],
160+
save_dir=None,
161+
class_dict=None,
222162
)
223-
store = SQLiteStore.open(tmp_path / "test.db")
224-
assert len(store.values()) == 1
225-
annotation = next(iter(store.values()))
226-
print(annotation)
227-
assert annotation.properties["type"] == "nucleus"
228-
assert annotation.geometry.centroid.x == 10.0
229-
assert annotation.geometry.centroid.y == 10.0
230-
store.close()
163+
assert len(output_dict["x"]) == 3
164+
assert len(output_dict["y"]) == 3
165+
assert len(output_dict["types"]) == 3
166+
assert len(output_dict["probs"]) == 3
167+
assert len(output_dict["x"][0]) == 270
168+
assert len(output_dict["x"][1]) == 52
169+
assert len(output_dict["x"][2]) == 0
170+
assert len(output_dict["y"][0]) == 270
171+
assert len(output_dict["y"][1]) == 52
172+
assert len(output_dict["y"][2]) == 0
173+
assert len(output_dict["types"][0]) == 270
174+
assert len(output_dict["types"][1]) == 52
175+
assert len(output_dict["types"][2]) == 0
176+
assert len(output_dict["probs"][0]) == 270
177+
assert len(output_dict["probs"][1]) == 52
178+
assert len(output_dict["probs"][2]) == 0
179+
180+
181+
def test_centroid_maps_to_detection_arrays() -> None:
182+
"""Convert centroid maps to detection arrays."""
183+
detection_maps = np.zeros((4, 4, 2), dtype=np.float32)
184+
detection_maps[1, 1, 0] = 1.0
185+
detection_maps[2, 3, 1] = 0.5
186+
detection_maps = da.from_array(detection_maps, chunks=(2, 2, 2))
187+
188+
detections = NucleusDetector._centroid_maps_to_detection_arrays(detection_maps)
189+
190+
xs = detections["x"]
191+
ys = detections["y"]
192+
types = detections["types"]
193+
probs = detections["probs"]
194+
195+
np.testing.assert_array_equal(xs, np.array([1, 3], dtype=np.uint32))
196+
np.testing.assert_array_equal(ys, np.array([1, 2], dtype=np.uint32))
197+
np.testing.assert_array_equal(types, np.array([0, 1], dtype=np.uint32))
198+
np.testing.assert_array_equal(probs, np.array([1.0, 0.5], dtype=np.float32))

0 commit comments

Comments
 (0)