Skip to content

Commit 0f8d49e

Browse files
committed
✅ Add checks for annotationstore output
1 parent 8321e0e commit 0f8d49e

File tree

1 file changed

+51
-30
lines changed

1 file changed

+51
-30
lines changed

tests/engines/test_nucleus_instance_segmentor.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
import torch
99
import zarr
1010

11+
from tiatoolbox.annotation.storage import SQLiteStore
1112
from tiatoolbox.models import NucleusInstanceSegmentor
1213
from tiatoolbox.wsicore import WSIReader
1314

1415
device = "cuda:0" if torch.cuda.is_available() else "cpu"
1516

1617

17-
def test_functionality_patch_mode(
18+
def test_functionality_patch_mode( # noqa: PLR0915
1819
remote_sample: Callable, track_tmp_path: Path
1920
) -> None:
2021
"""Patch mode functionality test for nuclei instance segmentor."""
@@ -138,43 +139,63 @@ def test_functionality_patch_mode(
138139
for a, b in zip(output["type"][1], output_["type"][1], strict=False)
139140
)
140141

141-
142-
def test_functionality_patch_mode_anns(
143-
remote_sample: Callable, track_tmp_path: Path
144-
) -> None:
145-
"""Patch mode functionality test for nuclei instance segmentor."""
146-
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
147-
mini_wsi = WSIReader.open(mini_wsi_svs)
148-
size = (256, 256)
149-
resolution = 0.25
150-
units: Final = "mpp"
151-
patch1 = mini_wsi.read_rect(
152-
location=(0, 0),
153-
size=size,
154-
resolution=resolution,
155-
units=units,
156-
)
157-
patch2 = mini_wsi.read_rect(
158-
location=(512, 512),
159-
size=size,
160-
resolution=resolution,
161-
units=units,
162-
)
163-
164-
# Test dummy input, should result in no output segmentation
165-
patch3 = np.zeros_like(patch1)
166-
167-
patches = np.stack(arrays=[patch1, patch2, patch3], axis=0)
168-
169142
inst_segmentor = NucleusInstanceSegmentor(
170143
batch_size=1,
171144
num_workers=0,
172145
model="hovernet_fast-pannuke",
173146
)
174-
_ = inst_segmentor.run(
147+
output = inst_segmentor.run(
175148
images=patches,
176149
patch_mode=True,
177150
device=device,
178151
output_type="annotationstore",
179152
save_dir=track_tmp_path / "patch_output_annotationstore",
180153
)
154+
155+
assert output[0] == track_tmp_path / "patch_output_annotationstore" / "0.db"
156+
assert len(output) == 3
157+
store_ = SQLiteStore.open(output[0])
158+
annotations_ = store_.values()
159+
annotations_geometry_type = [
160+
str(annotation_.geometry_type) for annotation_ in annotations_
161+
]
162+
assert "Polygon" in annotations_geometry_type
163+
164+
annotations_list = list(annotations_)
165+
ann_properties = [ann.properties for ann in annotations_list]
166+
167+
result = {}
168+
for d in ann_properties:
169+
for key, value in d.items():
170+
result.setdefault(key, []).append(value)
171+
172+
polygons = [ann.geometry for ann in annotations_list]
173+
result["contour"] = [list(poly.exterior.coords) for poly in polygons]
174+
175+
assert all(
176+
np.array_equal(a, b)
177+
for a, b in zip(result["box"], output_["box"][0], strict=False)
178+
)
179+
180+
assert all(
181+
np.array_equal(a, b)
182+
for a, b in zip(result["centroid"], output_["centroid"][0], strict=False)
183+
)
184+
185+
assert all(
186+
np.array_equal(a, b)
187+
for a, b in zip(result["prob"], output_["prob"][0], strict=False)
188+
)
189+
190+
assert all(
191+
np.array_equal(a, b)
192+
for a, b in zip(result["type"], output_["type"][0], strict=False)
193+
)
194+
195+
assert all(
196+
np.array_equal(
197+
np.array(a[:-1], dtype=int), # discard last point
198+
np.array(b, dtype=int),
199+
)
200+
for a, b in zip(result["contour"], output_["contour"][0], strict=False)
201+
)

0 commit comments

Comments
 (0)