1212from tiatoolbox .annotation .storage import SQLiteStore
1313from tiatoolbox .models .engine .nucleus_detector import (
1414 NucleusDetector ,
15- _flatten_predictions_to_dask ,
15+ # _flatten_predictions_to_dask,
1616)
1717from tiatoolbox .utils import env_detection as toolbox_env
1818from tiatoolbox .utils .misc import imwrite
@@ -54,7 +54,7 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
5454 store = SQLiteStore .open (save_dir / "wsi4_512_512.db" )
5555 assert 255 <= len (store .values ()) <= 265
5656 annotation = next (iter (store .values ()))
57- assert annotation .properties ["type " ] == "test_nucleus"
57+ assert annotation .properties ["class " ] == "test_nucleus"
5858 store .close ()
5959
6060 nucleus_detector .drop_keys = ["probs" ]
@@ -68,17 +68,18 @@ def test_nucleus_detector_wsi(remote_sample: Callable, tmp_path: pathlib.Path) -
6868 overwrite = True ,
6969 batch_size = 8 ,
7070 )
71+ print ("Result path:" , result_path )
7172
7273 zarr_path = result_path [mini_wsi_svs ]
7374 zarr_group = zarr .open (zarr_path , mode = "r" )
7475 xs = zarr_group ["x" ][:]
7576 ys = zarr_group ["y" ][:]
76- types = zarr_group ["types " ][:]
77+ classes = zarr_group ["classes " ][:]
7778 probs = zarr_group .get ("probs" , None )
7879 assert probs is None
7980 assert 255 <= len (xs ) <= 265
8081 assert 255 <= len (ys ) <= 265
81- assert 255 <= len (types ) <= 265
82+ assert 255 <= len (classes ) <= 265
8283
8384 _rm_dir (save_dir )
8485 pathlib .Path .unlink (mini_wsi_svs )
@@ -174,17 +175,16 @@ def test_nucleus_detector_patches_dict_output(
174175 save_dir = None ,
175176 class_dict = None ,
176177 )
177- output_dict = output_dict ["predictions" ]
178178 assert len (output_dict ["x" ]) == 2
179179 assert len (output_dict ["y" ]) == 2
180- assert len (output_dict ["types " ]) == 2
180+ assert len (output_dict ["classes " ]) == 2
181181 assert len (output_dict ["probs" ]) == 2
182182 assert len (output_dict ["x" ][0 ]) == 1
183183 assert len (output_dict ["x" ][1 ]) == 0
184184 assert len (output_dict ["y" ][0 ]) == 1
185185 assert len (output_dict ["y" ][1 ]) == 0
186- assert len (output_dict ["types " ][0 ]) == 1
187- assert len (output_dict ["types " ][1 ]) == 0
186+ assert len (output_dict ["classes " ][0 ]) == 1
187+ assert len (output_dict ["classes " ][1 ]) == 0
188188 assert len (output_dict ["probs" ][0 ]) == 1
189189 assert len (output_dict ["probs" ][1 ]) == 0
190190
@@ -220,40 +220,16 @@ def test_nucleus_detector_patches_zarr_output(
220220 overwrite = True ,
221221 )
222222
223- zarr_group = zarr .open (output_path , mode = "r" )
224- output_dict = {
225- "x" : zarr_group ["x" ][:],
226- "y" : zarr_group ["y" ][:],
227- "types" : zarr_group ["types" ][:],
228- "probs" : zarr_group ["probs" ][:],
229- "patch_offsets" : zarr_group ["patch_offsets" ][:],
230- }
231-
232- assert len (output_dict ["x" ]) == 1
233- assert len (output_dict ["y" ]) == 1
234- assert len (output_dict ["types" ]) == 1
235- assert len (output_dict ["probs" ]) == 1
236- assert len (output_dict ["patch_offsets" ]) == 3
237-
238- patch_1_start , patch_1_end = (
239- output_dict ["patch_offsets" ][0 ],
240- output_dict ["patch_offsets" ][1 ],
241- )
242- patch_2_start , patch_2_end = (
243- output_dict ["patch_offsets" ][1 ],
244- output_dict ["patch_offsets" ][2 ],
245- )
246- assert len (output_dict ["x" ][patch_1_start :patch_1_end ]) == 1
247- assert len (output_dict ["x" ][patch_2_start :patch_2_end ]) == 0
248-
249- assert len (output_dict ["y" ][patch_1_start :patch_1_end ]) == 1
250- assert len (output_dict ["y" ][patch_2_start :patch_2_end ]) == 0
223+ output_zarr = zarr .open (output_path , mode = "r" )
251224
252- assert len (output_dict ["types" ][patch_1_start :patch_1_end ]) == 1
253- assert len (output_dict ["types" ][patch_2_start :patch_2_end ]) == 0
254-
255- assert len (output_dict ["probs" ][patch_1_start :patch_1_end ]) == 1
256- assert len (output_dict ["probs" ][patch_2_start :patch_2_end ]) == 0
225+ assert output_zarr ["x" ][0 ].size == 1
226+ assert output_zarr ["x" ][1 ].size == 0
227+ assert output_zarr ["y" ][0 ].size == 1
228+ assert output_zarr ["y" ][1 ].size == 0
229+ assert output_zarr ["classes" ][0 ].size == 1
230+ assert output_zarr ["classes" ][1 ].size == 0
231+ assert output_zarr ["probs" ][0 ].size == 1
232+ assert output_zarr ["probs" ][1 ].size == 0
257233
258234 _rm_dir (save_dir )
259235
@@ -269,12 +245,12 @@ def test_centroid_maps_to_detection_arrays() -> None:
269245
270246 xs = detections ["x" ]
271247 ys = detections ["y" ]
272- types = detections ["types " ]
248+ classes = detections ["classes " ]
273249 probs = detections ["probs" ]
274250
275251 np .testing .assert_array_equal (xs , np .array ([1 , 3 ], dtype = np .uint32 ))
276252 np .testing .assert_array_equal (ys , np .array ([1 , 2 ], dtype = np .uint32 ))
277- np .testing .assert_array_equal (types , np .array ([0 , 1 ], dtype = np .uint32 ))
253+ np .testing .assert_array_equal (classes , np .array ([0 , 1 ], dtype = np .uint32 ))
278254 np .testing .assert_array_equal (probs , np .array ([1.0 , 0.5 ], dtype = np .float32 ))
279255
280256
@@ -283,65 +259,36 @@ def test_write_detection_arrays_to_store() -> None:
283259 detection_arrays = {
284260 "x" : np .array ([1 , 3 ], dtype = np .uint32 ),
285261 "y" : np .array ([1 , 2 ], dtype = np .uint32 ),
286- "types " : np .array ([0 , 1 ], dtype = np .uint32 ),
262+ "classes " : np .array ([0 , 1 ], dtype = np .uint32 ),
287263 "probs" : np .array ([1.0 , 0.5 ], dtype = np .float32 ),
288264 }
289265
290- store = NucleusDetector .write_detection_arrays_to_store (detection_arrays )
266+ store = NucleusDetector .save_detection_arrays_to_store (detection_arrays )
291267 assert len (store .values ()) == 2
292268
293269 detection_arrays = {
294270 "x" : np .array ([1 ], dtype = np .uint32 ),
295271 "y" : np .array ([1 , 2 ], dtype = np .uint32 ),
296- "types " : np .array ([0 ], dtype = np .uint32 ),
272+ "classes " : np .array ([0 ], dtype = np .uint32 ),
297273 "probs" : np .array ([1.0 , 0.5 ], dtype = np .float32 ),
298274 }
299275 with pytest .raises (
300276 ValueError ,
301277 match = r"Detection record lengths are misaligned." ,
302278 ):
303- _ = NucleusDetector .write_detection_arrays_to_store (detection_arrays )
279+ _ = NucleusDetector .save_detection_arrays_to_store (detection_arrays )
304280
305281
306282def test_write_detection_records_to_store_no_class_dict () -> None :
307283 """Test writing detection records to annotation store."""
308284 detection_records = (np .array ([1 ]), np .array ([2 ]), np .array ([0 ]), np .array ([1.0 ]))
309285
310286 dummy_store = SQLiteStore ()
311- total = NucleusDetector ._write_detection_records_to_store (
287+ total = NucleusDetector ._write_detection_arrays_to_store (
312288 detection_records , store = dummy_store , scale_factor = (1.0 , 1.0 ), class_dict = None
313289 )
314290 assert len (dummy_store .values ()) == 1
315291 assert total == 1
316292 annotation = next (iter (dummy_store .values ()))
317- assert annotation .properties ["type " ] == 0
293+ assert annotation .properties ["class " ] == 0
318294 dummy_store .close ()
319-
320-
321- def test_flatten_predictions_to_dask () -> None :
322- """Test flattening ragged predictions to Dask array."""
323- ragged_obj_array = np .empty (3 , dtype = object )
324- ragged_obj_array [0 ] = np .array ([1.0 , 0.0 ], dtype = np .float32 )
325- ragged_obj_array [1 ] = np .array ([0.5 , 0.5 ], dtype = np .float32 )
326- ragged_obj_array [2 ] = np .array ([0.2 , 0.8 , 0.8 , 0.2 ], dtype = np .float32 )
327-
328- ragged_da_array = da .from_array (ragged_obj_array , chunks = (len (ragged_obj_array ),))
329-
330- flat_dask_array = _flatten_predictions_to_dask (ragged_da_array )
331- expected_array = np .array (
332- [
333- 1.0 ,
334- 0.0 ,
335- 0.5 ,
336- 0.5 ,
337- 0.2 ,
338- 0.8 ,
339- 0.8 ,
340- 0.2 ,
341- ],
342- dtype = np .float32 ,
343- )
344- np .testing .assert_array_equal (flat_dask_array .compute (), expected_array )
345-
346- flat_dask_array = _flatten_predictions_to_dask (ragged_obj_array )
347- np .testing .assert_array_equal (flat_dask_array .compute (), expected_array )
0 commit comments