66
77import dask .array as da
88import numpy as np
9- import pandas as pd
10- import pytest
9+ import zarr
1110
1211from tiatoolbox .annotation .storage import SQLiteStore
1312from tiatoolbox .models .engine .nucleus_detector import NucleusDetector
@@ -28,91 +27,6 @@ def check_output(path: pathlib.Path) -> None:
2827 """Check NucleusDetector output."""
2928
3029
31- def test_nucleus_detection_nms_empty_dataframe () -> None :
32- """nucleus_detection_nms should return a copy for empty inputs."""
33- df = pd .DataFrame (columns = ["x" , "y" , "type" , "prob" ])
34-
35- result = NucleusDetector .nucleus_detection_nms (df , radius = 3 )
36-
37- assert result .empty
38- assert result is not df
39- assert list (result .columns ) == ["x" , "y" , "type" , "prob" ]
40-
41-
42- def test_nucleus_detection_nms_invalid_radius () -> None :
43- """Radius must be strictly positive."""
44- df = pd .DataFrame ({"x" : [0 ], "y" : [0 ], "type" : [1 ], "prob" : [0.9 ]})
45-
46- with pytest .raises (ValueError , match = "radius must be > 0" ):
47- NucleusDetector .nucleus_detection_nms (df , radius = 0 )
48-
49-
50- def test_nucleus_detection_nms_invalid_overlap_threshold () -> None :
51- """overlap_threshold must lie in (0, 1]."""
52- df = pd .DataFrame ({"x" : [0 ], "y" : [0 ], "type" : [1 ], "prob" : [0.9 ]})
53-
54- message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0"
55- with pytest .raises (ValueError , match = message ):
56- NucleusDetector .nucleus_detection_nms (df , radius = 1 , overlap_threshold = 0 )
57-
58-
59- def test_nucleus_detection_nms_suppresses_overlapping_detections () -> None :
60- """Lower-probability overlapping detections are removed."""
61- df = pd .DataFrame (
62- {
63- "x" : [2 , 0 , 20 ],
64- "y" : [1 , 0 , 20 ],
65- "type" : [1 , 1 , 2 ],
66- "prob" : [0.6 , 0.9 , 0.7 ],
67- }
68- )
69-
70- result = NucleusDetector .nucleus_detection_nms (df , radius = 5 )
71-
72- expected = pd .DataFrame (
73- {"x" : [0 , 20 ], "y" : [0 , 20 ], "type" : [1 , 2 ], "prob" : [0.9 , 0.7 ]}
74- )
75- pd .testing .assert_frame_equal (result .reset_index (drop = True ), expected )
76-
77-
78- def test_nucleus_detection_nms_suppresses_across_types () -> None :
79- """Overlapping detections of different types are also suppressed."""
80- df = pd .DataFrame (
81- {
82- "x" : [0 , 0 , 20 ],
83- "y" : [0 , 0 , 0 ],
84- "type" : [1 , 2 , 1 ],
85- "prob" : [0.6 , 0.95 , 0.4 ],
86- }
87- )
88-
89- result = NucleusDetector .nucleus_detection_nms (df , radius = 5 )
90-
91- expected = pd .DataFrame (
92- {"x" : [0 , 20 ], "y" : [0 , 0 ], "type" : [2 , 1 ], "prob" : [0.95 , 0.4 ]}
93- )
94- pd .testing .assert_frame_equal (result .reset_index (drop = True ), expected )
95-
96-
97- def test_nucleus_detection_nms_retains_non_overlapping_candidates () -> None :
98- """Detections with IoU below the threshold are preserved."""
99- df = pd .DataFrame (
100- {
101- "x" : [0 , 10 ],
102- "y" : [0 , 0 ],
103- "type" : [1 , 1 ],
104- "prob" : [0.8 , 0.5 ],
105- }
106- )
107-
108- result = NucleusDetector .nucleus_detection_nms (df , radius = 5 , overlap_threshold = 0.5 )
109-
110- expected = pd .DataFrame (
111- {"x" : [0 , 10 ], "y" : [0 , 0 ], "type" : [1 , 1 ], "prob" : [0.8 , 0.5 ]}
112- )
113- pd .testing .assert_frame_equal (result .reset_index (drop = True ), expected )
114-
115-
11630def test_nucleus_detector_wsi (remote_sample : Callable , tmp_path : pathlib .Path ) -> None :
11731 """Test for nucleus detection engine."""
11832 mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs" ))
@@ -136,10 +50,31 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
13650 assert len (store .values ()) == 281
13751 store .close ()
13852
53+ result_path = nucleus_detector .run (
54+ patch_mode = False ,
55+ device = device ,
56+ output_type = "zarr" ,
57+ memory_threshold = 50 ,
58+ images = [mini_wsi_svs ],
59+ save_dir = save_dir ,
60+ overwrite = True ,
61+ )
62+
63+ zarr_path = result_path [mini_wsi_svs ]
64+ zarr_group = zarr .open (zarr_path , mode = "r" )
65+ xs = zarr_group ["x" ][:]
66+ ys = zarr_group ["y" ][:]
67+ types = zarr_group ["types" ][:]
68+ probs = zarr_group ["probs" ][:]
69+ assert len (xs ) == 281
70+ assert len (ys ) == 281
71+ assert len (types ) == 281
72+ assert len (probs ) == 281
73+
13974 _rm_dir (save_dir )
14075
14176
142- def test_nucleus_detector_patch (
77+ def test_nucleus_detector_patch_annotation_store_output (
14378 remote_sample : Callable , tmp_path : pathlib .Path
14479) -> None :
14580 """Test for nucleus detection engine in patch mode."""
@@ -183,7 +118,7 @@ def test_nucleus_detector_patch(
183118 _ = nucleus_detector .run (
184119 patch_mode = True ,
185120 device = device ,
186- output_type = "zarr " ,
121+ output_type = "annotationstore " ,
187122 memory_threshold = 50 ,
188123 images = [save_dir / "patch_0.png" , save_dir / "patch_1.png" ],
189124 save_dir = save_dir ,
@@ -201,30 +136,63 @@ def test_nucleus_detector_patch(
201136 _rm_dir (save_dir )
202137
203138
204- def test_nucleus_detector_write_centroid_maps (tmp_path : pathlib .Path ) -> None :
205- """Test for _write_centroid_maps function."""
206- detection_maps = np .zeros ((20 , 20 , 1 ), dtype = np .uint8 )
207- detection_maps = da .from_array (detection_maps , chunks = (20 , 20 , 1 ))
139+ def test_nucleus_detector_patches_dict_output (
140+ remote_sample : Callable ,
141+ ) -> None :
142+ """Test for nucleus detection engine in patch mode."""
143+ mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs" ))
208144
209- store = NucleusDetector .write_centroid_maps_to_store (
210- detection_maps = detection_maps , class_dict = None
211- )
212- assert len (store .values ()) == 0
213- store .close ()
145+ wsi_reader = WSIReader .open (mini_wsi_svs )
146+ patch_1 = wsi_reader .read_rect ((0 , 0 ), (252 , 252 ), resolution = 0.5 , units = "mpp" )
147+ patch_2 = wsi_reader .read_rect ((252 , 252 ), (252 , 252 ), resolution = 0.5 , units = "mpp" )
148+ patch_3 = np .zeros ((252 , 252 , 3 ), dtype = np .uint8 )
149+
150+ pretrained_model = "mapde-conic"
151+
152+ nucleus_detector = NucleusDetector (model = pretrained_model )
214153
215- detection_maps = np .zeros ((20 , 20 , 1 ), dtype = np .uint8 )
216- detection_maps [10 , 10 , 0 ] = 1
217- detection_maps = da .from_array (detection_maps , chunks = (20 , 20 , 1 ))
218- _ = NucleusDetector .write_centroid_maps_to_store (
219- detection_maps = detection_maps ,
220- save_path = tmp_path / "test.db" ,
221- class_dict = {0 : "nucleus" },
154+ output_dict = nucleus_detector .run (
155+ patch_mode = True ,
156+ device = device ,
157+ output_type = "dict" ,
158+ memory_threshold = 50 ,
159+ images = [patch_1 , patch_2 , patch_3 ],
160+ save_dir = None ,
161+ class_dict = None ,
222162 )
223- store = SQLiteStore .open (tmp_path / "test.db" )
224- assert len (store .values ()) == 1
225- annotation = next (iter (store .values ()))
226- print (annotation )
227- assert annotation .properties ["type" ] == "nucleus"
228- assert annotation .geometry .centroid .x == 10.0
229- assert annotation .geometry .centroid .y == 10.0
230- store .close ()
163+ assert len (output_dict ["x" ]) == 3
164+ assert len (output_dict ["y" ]) == 3
165+ assert len (output_dict ["types" ]) == 3
166+ assert len (output_dict ["probs" ]) == 3
167+ assert len (output_dict ["x" ][0 ]) == 270
168+ assert len (output_dict ["x" ][1 ]) == 52
169+ assert len (output_dict ["x" ][2 ]) == 0
170+ assert len (output_dict ["y" ][0 ]) == 270
171+ assert len (output_dict ["y" ][1 ]) == 52
172+ assert len (output_dict ["y" ][2 ]) == 0
173+ assert len (output_dict ["types" ][0 ]) == 270
174+ assert len (output_dict ["types" ][1 ]) == 52
175+ assert len (output_dict ["types" ][2 ]) == 0
176+ assert len (output_dict ["probs" ][0 ]) == 270
177+ assert len (output_dict ["probs" ][1 ]) == 52
178+ assert len (output_dict ["probs" ][2 ]) == 0
179+
180+
181+ def test_centroid_maps_to_detection_arrays () -> None :
182+ """Convert centroid maps to detection arrays."""
183+ detection_maps = np .zeros ((4 , 4 , 2 ), dtype = np .float32 )
184+ detection_maps [1 , 1 , 0 ] = 1.0
185+ detection_maps [2 , 3 , 1 ] = 0.5
186+ detection_maps = da .from_array (detection_maps , chunks = (2 , 2 , 2 ))
187+
188+ detections = NucleusDetector ._centroid_maps_to_detection_arrays (detection_maps )
189+
190+ xs = detections ["x" ]
191+ ys = detections ["y" ]
192+ types = detections ["types" ]
193+ probs = detections ["probs" ]
194+
195+ np .testing .assert_array_equal (xs , np .array ([1 , 3 ], dtype = np .uint32 ))
196+ np .testing .assert_array_equal (ys , np .array ([1 , 2 ], dtype = np .uint32 ))
197+ np .testing .assert_array_equal (types , np .array ([0 , 1 ], dtype = np .uint32 ))
198+ np .testing .assert_array_equal (probs , np .array ([1.0 , 0.5 ], dtype = np .float32 ))
0 commit comments