11"""Test tiatoolbox.models.engine.nucleus_instance_segmentor."""
22
3- from collections .abc import Callable
3+ from collections .abc import Callable , Sequence
44from pathlib import Path
5- from typing import Final
5+ from typing import Any , Final
66
77import numpy as np
88import torch
1313from tiatoolbox .wsicore import WSIReader
1414
1515device = "cuda:0" if torch .cuda .is_available () else "cpu"
16+ OutputType = dict [str , Any ] | Any
1617
1718
18- def test_functionality_patch_mode ( # noqa: PLR0915
19+ def assert_output_lengths (output : OutputType , expected_counts : Sequence [int ]) -> None :
20+ """Assert lengths of output dict fields against expected counts."""
21+ for field in ["box" , "centroid" , "contour" , "prob" , "type" ]:
22+ for i , expected in enumerate (expected_counts ):
23+ assert len (output [field ][i ]) == expected , f"{ field } [{ i } ] mismatch"
24+
25+
26+ def assert_output_equal (
27+ output_a : OutputType ,
28+ output_b : OutputType ,
29+ fields : Sequence [str ],
30+ indices_a : Sequence [int ],
31+ indices_b : Sequence [int ],
32+ ) -> None :
33+ """Assert equality of arrays across outputs for given fields/indices."""
34+ for field in fields :
35+ for i_a , i_b in zip (indices_a , indices_b , strict = False ):
36+ left = output_a [field ][i_a ]
37+ right = output_b [field ][i_b ]
38+ assert all (
39+ np .array_equal (a , b ) for a , b in zip (left , right , strict = False )
40+ ), f"{ field } [{ i_a } ] vs { field } [{ i_b } ] mismatch"
41+
42+
43+ def assert_predictions_and_boxes (
44+ output : OutputType , expected_counts : Sequence [int ], * , is_zarr : bool = False
45+ ) -> None :
46+ """Assert predictions maxima and box lengths against expected counts."""
47+ # predictions maxima
48+ for idx , expected in enumerate (expected_counts ):
49+ if is_zarr and idx == 2 :
50+ # zarr output doesn't store predictions for patch 2
51+ continue
52+ assert np .max (output ["predictions" ][idx ][:]) == expected , (
53+ f"predictions[{ idx } ] mismatch"
54+ )
55+
56+ # box lengths
57+ for idx , expected in enumerate (expected_counts ):
58+ if is_zarr and idx < 2 :
59+ # for zarr, compare boxes only for patches 0 and 1
60+ continue
61+ assert len (output ["box" ][idx ]) == expected , f"box[{ idx } ] mismatch"
62+
63+
64+ def test_functionality_patch_mode (
1965 remote_sample : Callable , track_tmp_path : Path
2066) -> None :
2167 """Patch mode functionality test for nuclei instance segmentor."""
@@ -24,178 +70,105 @@ def test_functionality_patch_mode( # noqa: PLR0915
2470 size = (256 , 256 )
2571 resolution = 0.25
2672 units : Final = "mpp"
73+
2774 patch1 = mini_wsi .read_rect (
28- location = (0 , 0 ),
29- size = size ,
30- resolution = resolution ,
31- units = units ,
75+ location = (0 , 0 ), size = size , resolution = resolution , units = units
3276 )
3377 patch2 = mini_wsi .read_rect (
34- location = (512 , 512 ),
35- size = size ,
36- resolution = resolution ,
37- units = units ,
78+ location = (512 , 512 ), size = size , resolution = resolution , units = units
3879 )
39-
40- # Test dummy input, should result in no output segmentation
4180 patch3 = np .zeros_like (patch1 )
42-
43- patches = np .stack (arrays = [patch1 , patch2 , patch3 ], axis = 0 )
81+ patches = np .stack ([patch1 , patch2 , patch3 ], axis = 0 )
4482
4583 inst_segmentor = NucleusInstanceSegmentor (
46- batch_size = 1 ,
47- num_workers = 0 ,
48- model = "hovernet_fast-pannuke" ,
84+ batch_size = 1 , num_workers = 0 , model = "hovernet_fast-pannuke"
4985 )
50- output = inst_segmentor .run (
51- images = patches ,
52- patch_mode = True ,
53- device = device ,
54- output_type = "dict" ,
86+ output_dict = inst_segmentor .run (
87+ images = patches , patch_mode = True , device = device , output_type = "dict"
5588 )
5689
57- assert np .max (output ["predictions" ][0 ][:]) == 41
58- assert np .max (output ["predictions" ][1 ][:]) == 17
59- assert np .max (output ["predictions" ][2 ][:]) == 0
60-
61- assert len (output ["box" ][0 ]) == 41
62- assert len (output ["box" ][1 ]) == 17
63- assert len (output ["box" ][2 ]) == 0
64-
65- assert len (output ["centroid" ][0 ]) == 41
66- assert len (output ["centroid" ][1 ]) == 17
67- assert len (output ["centroid" ][2 ]) == 0
68-
69- assert len (output ["contour" ][0 ]) == 41
70- assert len (output ["contour" ][1 ]) == 17
71- assert len (output ["contour" ][2 ]) == 0
72-
73- assert len (output ["prob" ][0 ]) == 41
74- assert len (output ["prob" ][1 ]) == 17
75- assert len (output ["prob" ][2 ]) == 0
90+ expected_counts = [41 , 17 , 0 ]
7691
77- assert len (output ["type" ][0 ]) == 41
78- assert len (output ["type" ][1 ]) == 17
79- assert len (output ["type" ][2 ]) == 0
92+ assert_predictions_and_boxes (output_dict , expected_counts , is_zarr = False )
93+ assert_output_lengths (output_dict , expected_counts )
8094
81- output_ = output
82-
83- output = inst_segmentor .run (
95+ # Zarr output comparison
96+ output_zarr = inst_segmentor .run (
8497 images = patches ,
8598 patch_mode = True ,
8699 device = device ,
87100 output_type = "zarr" ,
88101 save_dir = track_tmp_path / "patch_output_zarr" ,
89102 )
103+ output_zarr = zarr .open (output_zarr , mode = "r" )
104+ assert_predictions_and_boxes (output_zarr , expected_counts , is_zarr = True )
90105
91- output = zarr .open (output , mode = "r" )
92-
93- assert np .max (output ["predictions" ][0 ][:]) == 41
94- assert np .max (output ["predictions" ][1 ][:]) == 17
95-
96- assert all (
97- np .array_equal (a , b )
98- for a , b in zip (output ["box" ][0 ], output_ ["box" ][0 ], strict = False )
106+ assert_output_equal (
107+ output_zarr ,
108+ output_dict ,
109+ fields = ["box" , "centroid" , "contour" , "prob" , "type" ],
110+ indices_a = [0 , 1 , 2 ],
111+ indices_b = [0 , 1 , 2 ],
99112 )
100- assert all (
101- np .array_equal (a , b )
102- for a , b in zip (output ["box" ][1 ], output_ ["box" ][1 ], strict = False )
103- )
104- assert len (output ["box" ][2 ]) == 0
105113
106- assert all (
107- np .array_equal (a , b )
108- for a , b in zip (output ["centroid" ][0 ], output_ ["centroid" ][0 ], strict = False )
109- )
110- assert all (
111- np .array_equal (a , b )
112- for a , b in zip (output ["centroid" ][1 ], output_ ["centroid" ][1 ], strict = False )
113- )
114-
115- assert all (
116- np .array_equal (a , b )
117- for a , b in zip (output ["contour" ][0 ], output_ ["contour" ][0 ], strict = False )
118- )
119- assert all (
120- np .array_equal (a , b )
121- for a , b in zip (output ["contour" ][1 ], output_ ["contour" ][1 ], strict = False )
122- )
123-
124- assert all (
125- np .array_equal (a , b )
126- for a , b in zip (output ["prob" ][0 ], output_ ["prob" ][0 ], strict = False )
127- )
128- assert all (
129- np .array_equal (a , b )
130- for a , b in zip (output ["prob" ][1 ], output_ ["prob" ][1 ], strict = False )
131- )
132-
133- assert all (
134- np .array_equal (a , b )
135- for a , b in zip (output ["type" ][0 ], output_ ["type" ][0 ], strict = False )
136- )
137- assert all (
138- np .array_equal (a , b )
139- for a , b in zip (output ["type" ][1 ], output_ ["type" ][1 ], strict = False )
140- )
141-
142- inst_segmentor = NucleusInstanceSegmentor (
143- batch_size = 1 ,
144- num_workers = 0 ,
145- model = "hovernet_fast-pannuke" ,
146- )
147- output = inst_segmentor .run (
114+ # AnnotationStore output comparison
115+ output_ann = inst_segmentor .run (
148116 images = patches ,
149117 patch_mode = True ,
150118 device = device ,
151119 output_type = "annotationstore" ,
152120 save_dir = track_tmp_path / "patch_output_annotationstore" ,
153121 )
122+ assert len (output_ann ) == 3
123+ assert output_ann [0 ] == track_tmp_path / "patch_output_annotationstore" / "0.db"
154124
155- assert output [0 ] == track_tmp_path / "patch_output_annotationstore" / "0.db"
156- assert len (output ) == 3
157- store_ = SQLiteStore .open (output [0 ])
158- annotations_ = store_ .values ()
159- annotations_geometry_type = [
160- str (annotation_ .geometry_type ) for annotation_ in annotations_
161- ]
162- assert "Polygon" in annotations_geometry_type
163-
164- annotations_list = list (annotations_ )
165- ann_properties = [ann .properties for ann in annotations_list ]
166-
167- result = {}
168- for d in ann_properties :
169- for key , value in d .items ():
170- result .setdefault (key , []).append (value )
171-
172- polygons = [ann .geometry for ann in annotations_list ]
173- result ["contour" ] = [list (poly .exterior .coords ) for poly in polygons ]
174-
175- assert all (
176- np .array_equal (a , b )
177- for a , b in zip (result ["box" ], output_ ["box" ][0 ], strict = False )
178- )
179-
180- assert all (
181- np .array_equal (a , b )
182- for a , b in zip (result ["centroid" ], output_ ["centroid" ][0 ], strict = False )
183- )
184-
185- assert all (
186- np .array_equal (a , b )
187- for a , b in zip (result ["prob" ], output_ ["prob" ][0 ], strict = False )
188- )
189-
190- assert all (
191- np .array_equal (a , b )
192- for a , b in zip (result ["type" ], output_ ["type" ][0 ], strict = False )
193- )
194-
195- assert all (
196- np .array_equal (
197- np .array (a [:- 1 ], dtype = int ), # discard last point
198- np .array (b , dtype = int ),
125+ for patch_idx , db_path in enumerate (output_ann ):
126+ assert (
127+ db_path
128+ == track_tmp_path / "patch_output_annotationstore" / f"{ patch_idx } .db"
199129 )
200- for a , b in zip (result ["contour" ], output_ ["contour" ][0 ], strict = False )
201- )
130+ store_ = SQLiteStore .open (db_path )
131+ annotations_ = store_ .values ()
132+ annotations_geometry_type = [
133+ str (annotation_ .geometry_type ) for annotation_ in annotations_
134+ ]
135+ annotations_list = list (annotations_ )
136+ if expected_counts [patch_idx ] > 0 :
137+ assert "Polygon" in annotations_geometry_type
138+
139+ # Build result dict from annotation properties
140+ result = {}
141+ for ann in annotations_list :
142+ for key , value in ann .properties .items ():
143+ result .setdefault (key , []).append (value )
144+ result ["contour" ] = [
145+ list (poly .exterior .coords )
146+ for poly in (a .geometry for a in annotations_list )
147+ ]
148+
149+ # wrap it to make it compatible to assert_output_lengths
150+ result_ = {
151+ field : [result [field ]]
152+ for field in ["box" , "centroid" , "contour" , "prob" , "type" ]
153+ }
154+
155+ # Lengths and equality checks for this patch
156+ assert_output_lengths (result_ , [expected_counts [patch_idx ]])
157+ assert_output_equal (
158+ result_ ,
159+ output_dict ,
160+ fields = ["box" , "centroid" , "prob" , "type" ],
161+ indices_a = [0 ],
162+ indices_b = [patch_idx ],
163+ )
164+
165+ # Contour check (discard last point)
166+ assert all (
167+ np .array_equal (np .array (a [:- 1 ], dtype = int ), np .array (b , dtype = int ))
168+ for a , b in zip (
169+ result ["contour" ], output_dict ["contour" ][patch_idx ], strict = False
170+ )
171+ )
172+ else :
173+ assert annotations_geometry_type == []
174+ assert annotations_list == []
0 commit comments