Skip to content

Commit 0a72e8b

Browse files
committed
improve tests
1 parent 7912abe commit 0a72e8b

File tree

3 files changed

+178
-5
lines changed

3 files changed

+178
-5
lines changed

tests/engines/test_nucleus_detection_engine.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
import shutil
55
from collections.abc import Callable
66

7+
import pandas as pd
8+
import pytest
9+
710
from tiatoolbox.annotation.storage import SQLiteStore
811
from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
912
from tiatoolbox.utils import env_detection as toolbox_env
13+
from tiatoolbox.wsicore.wsireader import WSIReader
1014

1115
device = "cuda" if toolbox_env.has_gpu() else "cpu"
1216

@@ -19,8 +23,91 @@ def _rm_dir(path: pathlib.Path) -> None:
1923

2024
def check_output(path: pathlib.Path) -> None:
2125
"""Check NucleusDetector output."""
22-
store = SQLiteStore.open(path)
23-
assert len(store.values()) == 281
26+
27+
28+
def test_nucleus_detection_nms_empty_dataframe() -> None:
29+
"""nucleus_detection_nms should return a copy for empty inputs."""
30+
df = pd.DataFrame(columns=["x", "y", "type", "prob"])
31+
32+
result = NucleusDetector.nucleus_detection_nms(df, radius=3)
33+
34+
assert result.empty
35+
assert result is not df
36+
assert list(result.columns) == ["x", "y", "type", "prob"]
37+
38+
39+
def test_nucleus_detection_nms_invalid_radius() -> None:
40+
"""Radius must be strictly positive."""
41+
df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})
42+
43+
with pytest.raises(ValueError, match="radius must be > 0"):
44+
NucleusDetector.nucleus_detection_nms(df, radius=0)
45+
46+
47+
def test_nucleus_detection_nms_invalid_overlap_threshold() -> None:
48+
"""overlap_threshold must lie in (0, 1]."""
49+
df = pd.DataFrame({"x": [0], "y": [0], "type": [1], "prob": [0.9]})
50+
51+
message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0"
52+
with pytest.raises(ValueError, match=message):
53+
NucleusDetector.nucleus_detection_nms(df, radius=1, overlap_threshold=0)
54+
55+
56+
def test_nucleus_detection_nms_suppresses_overlapping_detections() -> None:
57+
"""Lower-probability overlapping detections are removed."""
58+
df = pd.DataFrame(
59+
{
60+
"x": [2, 0, 20],
61+
"y": [1, 0, 20],
62+
"type": [1, 1, 2],
63+
"prob": [0.6, 0.9, 0.7],
64+
}
65+
)
66+
67+
result = NucleusDetector.nucleus_detection_nms(df, radius=5)
68+
69+
expected = pd.DataFrame(
70+
{"x": [0, 20], "y": [0, 20], "type": [1, 2], "prob": [0.9, 0.7]}
71+
)
72+
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)
73+
74+
75+
def test_nucleus_detection_nms_suppresses_across_types() -> None:
76+
"""Overlapping detections of different types are also suppressed."""
77+
df = pd.DataFrame(
78+
{
79+
"x": [0, 0, 20],
80+
"y": [0, 0, 0],
81+
"type": [1, 2, 1],
82+
"prob": [0.6, 0.95, 0.4],
83+
}
84+
)
85+
86+
result = NucleusDetector.nucleus_detection_nms(df, radius=5)
87+
88+
expected = pd.DataFrame(
89+
{"x": [0, 20], "y": [0, 0], "type": [2, 1], "prob": [0.95, 0.4]}
90+
)
91+
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)
92+
93+
94+
def test_nucleus_detection_nms_retains_non_overlapping_candidates() -> None:
95+
"""Detections with IoU below the threshold are preserved."""
96+
df = pd.DataFrame(
97+
{
98+
"x": [0, 10],
99+
"y": [0, 0],
100+
"type": [1, 1],
101+
"prob": [0.8, 0.5],
102+
}
103+
)
104+
105+
result = NucleusDetector.nucleus_detection_nms(df, radius=5, overlap_threshold=0.5)
106+
107+
expected = pd.DataFrame(
108+
{"x": [0, 10], "y": [0, 0], "type": [1, 1], "prob": [0.8, 0.5]}
109+
)
110+
pd.testing.assert_frame_equal(result.reset_index(drop=True), expected)
24111

25112

26113
def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -> None:
@@ -36,13 +123,50 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
36123
patch_mode=False,
37124
device=device,
38125
output_type="annotationstore",
39-
auto_get_mask=True,
40126
memory_threshold=50,
41127
images=[mini_wsi_svs],
42128
save_dir=save_dir,
43129
overwrite=True,
44130
)
45131

46-
check_output(save_dir / "wsi4_512_512.db")
132+
store = SQLiteStore.open(save_dir / "wsi4_512_512.db")
133+
assert len(store.values()) == 281
134+
store.close()
135+
136+
_rm_dir(save_dir)
137+
138+
139+
def test_nucleus_detector_patch(
140+
remote_sample: Callable, tmp_path: pathlib.Path
141+
) -> None:
142+
"""Test for nucleus detection engine in patch mode."""
143+
mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs"))
144+
145+
wsi_reader = WSIReader.open(mini_wsi_svs)
146+
patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp")
147+
patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp")
148+
149+
pretrained_model = "mapde-conic"
150+
151+
save_dir = tmp_path
152+
153+
nucleus_detector = NucleusDetector(model=pretrained_model)
154+
_ = nucleus_detector.run(
155+
patch_mode=True,
156+
device=device,
157+
output_type="annotationstore",
158+
memory_threshold=50,
159+
images=[patch_1, patch_2],
160+
save_dir=save_dir,
161+
overwrite=True,
162+
)
163+
164+
store_1 = SQLiteStore.open(save_dir / "0.db")
165+
assert len(store_1.values()) == 270
166+
store_1.close()
167+
168+
store_2 = SQLiteStore.open(save_dir / "1.db")
169+
assert len(store_2.values()) == 52
170+
store_2.close()
47171

48172
_rm_dir(save_dir)

tests/models/test_arch_mapde.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,31 @@ def test_functionality(remote_sample: Callable) -> None:
5353

5454
np.testing.assert_array_equal(xs[0:2], np.array([242, 192]))
5555
np.testing.assert_array_equal(ys[0:2], np.array([10, 13]))
56+
57+
patch = reader.read_bounds(
58+
(0, 0, 252, 252),
59+
resolution=0.50,
60+
units="mpp",
61+
coord_space="resolution",
62+
)
63+
64+
model, weights_path = _load_mapde(name="mapde-conic")
65+
patch = model.preproc(patch)
66+
batch = torch.from_numpy(patch)[None]
67+
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
68+
block_info = {
69+
0: {
70+
"array-location": [
71+
[0, 1],
72+
[0, 1],
73+
], # dummy block to test no valid detections
74+
}
75+
}
76+
output = model.postproc(output[0], block_info=block_info)
77+
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
78+
np.testing.assert_array_equal(xs, np.array([]))
79+
np.testing.assert_array_equal(ys, np.array([]))
80+
5681
Path(weights_path).unlink()
5782

5883

tests/models/test_arch_sccnn.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,31 @@ def test_functionality(remote_sample: Callable) -> None:
6060
batch,
6161
device=select_device(on_gpu=env_detection.has_gpu()),
6262
)
63-
output = model.postproc(output[0])
63+
block_info = {
64+
0: {
65+
"array-location": [[0, 31], [0, 31]],
66+
}
67+
}
68+
output = model.postproc(output[0], block_info=block_info)
6469
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
6570
np.testing.assert_array_equal(xs, np.array([7]))
6671
np.testing.assert_array_equal(ys, np.array([8]))
72+
73+
model = _load_sccnn(name="sccnn-conic")
74+
output = model.infer_batch(
75+
model,
76+
batch,
77+
device=select_device(on_gpu=env_detection.has_gpu()),
78+
)
79+
block_info = {
80+
0: {
81+
"array-location": [
82+
[0, 1],
83+
[0, 1],
84+
], # dummy block to test no valid detections
85+
}
86+
}
87+
output = model.postproc(output[0], block_info=block_info)
88+
xs, ys, _, _ = NucleusDetector._centroid_maps_to_detection_records(output, None)
89+
np.testing.assert_array_equal(xs, np.array([]))
90+
np.testing.assert_array_equal(ys, np.array([]))

0 commit comments

Comments
 (0)