44import shutil
55from collections .abc import Callable
66
7+ import pandas as pd
8+ import pytest
9+
710from tiatoolbox .annotation .storage import SQLiteStore
811from tiatoolbox .models .engine .nucleus_detector import NucleusDetector
912from tiatoolbox .utils import env_detection as toolbox_env
13+ from tiatoolbox .wsicore .wsireader import WSIReader
1014
1115device = "cuda" if toolbox_env .has_gpu () else "cpu"
1216
@@ -19,8 +23,91 @@ def _rm_dir(path: pathlib.Path) -> None:
1923
2024def check_output (path : pathlib .Path ) -> None :
2125 """Check NucleusDetector output."""
22- store = SQLiteStore .open (path )
23- assert len (store .values ()) == 281
26+
27+
28+ def test_nucleus_detection_nms_empty_dataframe () -> None :
29+ """nucleus_detection_nms should return a copy for empty inputs."""
30+ df = pd .DataFrame (columns = ["x" , "y" , "type" , "prob" ])
31+
32+ result = NucleusDetector .nucleus_detection_nms (df , radius = 3 )
33+
34+ assert result .empty
35+ assert result is not df
36+ assert list (result .columns ) == ["x" , "y" , "type" , "prob" ]
37+
38+
39+ def test_nucleus_detection_nms_invalid_radius () -> None :
40+ """Radius must be strictly positive."""
41+ df = pd .DataFrame ({"x" : [0 ], "y" : [0 ], "type" : [1 ], "prob" : [0.9 ]})
42+
43+ with pytest .raises (ValueError , match = "radius must be > 0" ):
44+ NucleusDetector .nucleus_detection_nms (df , radius = 0 )
45+
46+
47+ def test_nucleus_detection_nms_invalid_overlap_threshold () -> None :
48+ """overlap_threshold must lie in (0, 1]."""
49+ df = pd .DataFrame ({"x" : [0 ], "y" : [0 ], "type" : [1 ], "prob" : [0.9 ]})
50+
51+ message = r"overlap_threshold must be in \(0\.0, 1\.0\], got 0"
52+ with pytest .raises (ValueError , match = message ):
53+ NucleusDetector .nucleus_detection_nms (df , radius = 1 , overlap_threshold = 0 )
54+
55+
56+ def test_nucleus_detection_nms_suppresses_overlapping_detections () -> None :
57+ """Lower-probability overlapping detections are removed."""
58+ df = pd .DataFrame (
59+ {
60+ "x" : [2 , 0 , 20 ],
61+ "y" : [1 , 0 , 20 ],
62+ "type" : [1 , 1 , 2 ],
63+ "prob" : [0.6 , 0.9 , 0.7 ],
64+ }
65+ )
66+
67+ result = NucleusDetector .nucleus_detection_nms (df , radius = 5 )
68+
69+ expected = pd .DataFrame (
70+ {"x" : [0 , 20 ], "y" : [0 , 20 ], "type" : [1 , 2 ], "prob" : [0.9 , 0.7 ]}
71+ )
72+ pd .testing .assert_frame_equal (result .reset_index (drop = True ), expected )
73+
74+
75+ def test_nucleus_detection_nms_suppresses_across_types () -> None :
76+ """Overlapping detections of different types are also suppressed."""
77+ df = pd .DataFrame (
78+ {
79+ "x" : [0 , 0 , 20 ],
80+ "y" : [0 , 0 , 0 ],
81+ "type" : [1 , 2 , 1 ],
82+ "prob" : [0.6 , 0.95 , 0.4 ],
83+ }
84+ )
85+
86+ result = NucleusDetector .nucleus_detection_nms (df , radius = 5 )
87+
88+ expected = pd .DataFrame (
89+ {"x" : [0 , 20 ], "y" : [0 , 0 ], "type" : [2 , 1 ], "prob" : [0.95 , 0.4 ]}
90+ )
91+ pd .testing .assert_frame_equal (result .reset_index (drop = True ), expected )
92+
93+
94+ def test_nucleus_detection_nms_retains_non_overlapping_candidates () -> None :
95+ """Detections with IoU below the threshold are preserved."""
96+ df = pd .DataFrame (
97+ {
98+ "x" : [0 , 10 ],
99+ "y" : [0 , 0 ],
100+ "type" : [1 , 1 ],
101+ "prob" : [0.8 , 0.5 ],
102+ }
103+ )
104+
105+ result = NucleusDetector .nucleus_detection_nms (df , radius = 5 , overlap_threshold = 0.5 )
106+
107+ expected = pd .DataFrame (
108+ {"x" : [0 , 10 ], "y" : [0 , 0 ], "type" : [1 , 1 ], "prob" : [0.8 , 0.5 ]}
109+ )
110+ pd .testing .assert_frame_equal (result .reset_index (drop = True ), expected )
24111
25112
26113def test_nucleus_detector_wsi (remote_sample : Callable , tmp_path : pathlib .Path ) -> None :
@@ -36,13 +123,50 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
36123 patch_mode = False ,
37124 device = device ,
38125 output_type = "annotationstore" ,
39- auto_get_mask = True ,
40126 memory_threshold = 50 ,
41127 images = [mini_wsi_svs ],
42128 save_dir = save_dir ,
43129 overwrite = True ,
44130 )
45131
46- check_output (save_dir / "wsi4_512_512.db" )
132+ store = SQLiteStore .open (save_dir / "wsi4_512_512.db" )
133+ assert len (store .values ()) == 281
134+ store .close ()
135+
136+ _rm_dir (save_dir )
137+
138+
139+ def test_nucleus_detector_patch (
140+ remote_sample : Callable , tmp_path : pathlib .Path
141+ ) -> None :
142+ """Test for nucleus detection engine in patch mode."""
143+ mini_wsi_svs = pathlib .Path (remote_sample ("wsi4_512_512_svs" ))
144+
145+ wsi_reader = WSIReader .open (mini_wsi_svs )
146+ patch_1 = wsi_reader .read_rect ((0 , 0 ), (252 , 252 ), resolution = 0.5 , units = "mpp" )
147+ patch_2 = wsi_reader .read_rect ((252 , 252 ), (252 , 252 ), resolution = 0.5 , units = "mpp" )
148+
149+ pretrained_model = "mapde-conic"
150+
151+ save_dir = tmp_path
152+
153+ nucleus_detector = NucleusDetector (model = pretrained_model )
154+ _ = nucleus_detector .run (
155+ patch_mode = True ,
156+ device = device ,
157+ output_type = "annotationstore" ,
158+ memory_threshold = 50 ,
159+ images = [patch_1 , patch_2 ],
160+ save_dir = save_dir ,
161+ overwrite = True ,
162+ )
163+
164+ store_1 = SQLiteStore .open (save_dir / "0.db" )
165+ assert len (store_1 .values ()) == 270
166+ store_1 .close ()
167+
168+ store_2 = SQLiteStore .open (save_dir / "1.db" )
169+ assert len (store_2 .values ()) == 52
170+ store_2 .close ()
47171
48172 _rm_dir (save_dir )
0 commit comments