Skip to content

Commit a911d3c

Browse files
committed
🎨 Improve structure of the test
1 parent 0f8d49e commit a911d3c

File tree

1 file changed

+122
-149
lines changed

1 file changed

+122
-149
lines changed
Lines changed: 122 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Test tiatoolbox.models.engine.nucleus_instance_segmentor."""
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Sequence
44
from pathlib import Path
5-
from typing import Final
5+
from typing import Any, Final
66

77
import numpy as np
88
import torch
@@ -13,9 +13,55 @@
1313
from tiatoolbox.wsicore import WSIReader
1414

1515
device = "cuda:0" if torch.cuda.is_available() else "cpu"
16+
OutputType = dict[str, Any] | Any
1617

1718

18-
def test_functionality_patch_mode( # noqa: PLR0915
19+
def assert_output_lengths(output: OutputType, expected_counts: Sequence[int]) -> None:
20+
"""Assert lengths of output dict fields against expected counts."""
21+
for field in ["box", "centroid", "contour", "prob", "type"]:
22+
for i, expected in enumerate(expected_counts):
23+
assert len(output[field][i]) == expected, f"{field}[{i}] mismatch"
24+
25+
26+
def assert_output_equal(
27+
output_a: OutputType,
28+
output_b: OutputType,
29+
fields: Sequence[str],
30+
indices_a: Sequence[int],
31+
indices_b: Sequence[int],
32+
) -> None:
33+
"""Assert equality of arrays across outputs for given fields/indices."""
34+
for field in fields:
35+
for i_a, i_b in zip(indices_a, indices_b, strict=False):
36+
left = output_a[field][i_a]
37+
right = output_b[field][i_b]
38+
assert all(
39+
np.array_equal(a, b) for a, b in zip(left, right, strict=False)
40+
), f"{field}[{i_a}] vs {field}[{i_b}] mismatch"
41+
42+
43+
def assert_predictions_and_boxes(
44+
output: OutputType, expected_counts: Sequence[int], *, is_zarr: bool = False
45+
) -> None:
46+
"""Assert predictions maxima and box lengths against expected counts."""
47+
# predictions maxima
48+
for idx, expected in enumerate(expected_counts):
49+
if is_zarr and idx == 2:
50+
# zarr output doesn't store predictions for patch 2
51+
continue
52+
assert np.max(output["predictions"][idx][:]) == expected, (
53+
f"predictions[{idx}] mismatch"
54+
)
55+
56+
# box lengths
57+
for idx, expected in enumerate(expected_counts):
58+
if is_zarr and idx < 2:
59+
# for zarr, compare boxes only for patches 0 and 1
60+
continue
61+
assert len(output["box"][idx]) == expected, f"box[{idx}] mismatch"
62+
63+
64+
def test_functionality_patch_mode(
1965
remote_sample: Callable, track_tmp_path: Path
2066
) -> None:
2167
"""Patch mode functionality test for nuclei instance segmentor."""
@@ -24,178 +70,105 @@ def test_functionality_patch_mode( # noqa: PLR0915
2470
size = (256, 256)
2571
resolution = 0.25
2672
units: Final = "mpp"
73+
2774
patch1 = mini_wsi.read_rect(
28-
location=(0, 0),
29-
size=size,
30-
resolution=resolution,
31-
units=units,
75+
location=(0, 0), size=size, resolution=resolution, units=units
3276
)
3377
patch2 = mini_wsi.read_rect(
34-
location=(512, 512),
35-
size=size,
36-
resolution=resolution,
37-
units=units,
78+
location=(512, 512), size=size, resolution=resolution, units=units
3879
)
39-
40-
# Test dummy input, should result in no output segmentation
4180
patch3 = np.zeros_like(patch1)
42-
43-
patches = np.stack(arrays=[patch1, patch2, patch3], axis=0)
81+
patches = np.stack([patch1, patch2, patch3], axis=0)
4482

4583
inst_segmentor = NucleusInstanceSegmentor(
46-
batch_size=1,
47-
num_workers=0,
48-
model="hovernet_fast-pannuke",
84+
batch_size=1, num_workers=0, model="hovernet_fast-pannuke"
4985
)
50-
output = inst_segmentor.run(
51-
images=patches,
52-
patch_mode=True,
53-
device=device,
54-
output_type="dict",
86+
output_dict = inst_segmentor.run(
87+
images=patches, patch_mode=True, device=device, output_type="dict"
5588
)
5689

57-
assert np.max(output["predictions"][0][:]) == 41
58-
assert np.max(output["predictions"][1][:]) == 17
59-
assert np.max(output["predictions"][2][:]) == 0
60-
61-
assert len(output["box"][0]) == 41
62-
assert len(output["box"][1]) == 17
63-
assert len(output["box"][2]) == 0
64-
65-
assert len(output["centroid"][0]) == 41
66-
assert len(output["centroid"][1]) == 17
67-
assert len(output["centroid"][2]) == 0
68-
69-
assert len(output["contour"][0]) == 41
70-
assert len(output["contour"][1]) == 17
71-
assert len(output["contour"][2]) == 0
72-
73-
assert len(output["prob"][0]) == 41
74-
assert len(output["prob"][1]) == 17
75-
assert len(output["prob"][2]) == 0
90+
expected_counts = [41, 17, 0]
7691

77-
assert len(output["type"][0]) == 41
78-
assert len(output["type"][1]) == 17
79-
assert len(output["type"][2]) == 0
92+
assert_predictions_and_boxes(output_dict, expected_counts, is_zarr=False)
93+
assert_output_lengths(output_dict, expected_counts)
8094

81-
output_ = output
82-
83-
output = inst_segmentor.run(
95+
# Zarr output comparison
96+
output_zarr = inst_segmentor.run(
8497
images=patches,
8598
patch_mode=True,
8699
device=device,
87100
output_type="zarr",
88101
save_dir=track_tmp_path / "patch_output_zarr",
89102
)
103+
output_zarr = zarr.open(output_zarr, mode="r")
104+
assert_predictions_and_boxes(output_zarr, expected_counts, is_zarr=True)
90105

91-
output = zarr.open(output, mode="r")
92-
93-
assert np.max(output["predictions"][0][:]) == 41
94-
assert np.max(output["predictions"][1][:]) == 17
95-
96-
assert all(
97-
np.array_equal(a, b)
98-
for a, b in zip(output["box"][0], output_["box"][0], strict=False)
106+
assert_output_equal(
107+
output_zarr,
108+
output_dict,
109+
fields=["box", "centroid", "contour", "prob", "type"],
110+
indices_a=[0, 1, 2],
111+
indices_b=[0, 1, 2],
99112
)
100-
assert all(
101-
np.array_equal(a, b)
102-
for a, b in zip(output["box"][1], output_["box"][1], strict=False)
103-
)
104-
assert len(output["box"][2]) == 0
105113

106-
assert all(
107-
np.array_equal(a, b)
108-
for a, b in zip(output["centroid"][0], output_["centroid"][0], strict=False)
109-
)
110-
assert all(
111-
np.array_equal(a, b)
112-
for a, b in zip(output["centroid"][1], output_["centroid"][1], strict=False)
113-
)
114-
115-
assert all(
116-
np.array_equal(a, b)
117-
for a, b in zip(output["contour"][0], output_["contour"][0], strict=False)
118-
)
119-
assert all(
120-
np.array_equal(a, b)
121-
for a, b in zip(output["contour"][1], output_["contour"][1], strict=False)
122-
)
123-
124-
assert all(
125-
np.array_equal(a, b)
126-
for a, b in zip(output["prob"][0], output_["prob"][0], strict=False)
127-
)
128-
assert all(
129-
np.array_equal(a, b)
130-
for a, b in zip(output["prob"][1], output_["prob"][1], strict=False)
131-
)
132-
133-
assert all(
134-
np.array_equal(a, b)
135-
for a, b in zip(output["type"][0], output_["type"][0], strict=False)
136-
)
137-
assert all(
138-
np.array_equal(a, b)
139-
for a, b in zip(output["type"][1], output_["type"][1], strict=False)
140-
)
141-
142-
inst_segmentor = NucleusInstanceSegmentor(
143-
batch_size=1,
144-
num_workers=0,
145-
model="hovernet_fast-pannuke",
146-
)
147-
output = inst_segmentor.run(
114+
# AnnotationStore output comparison
115+
output_ann = inst_segmentor.run(
148116
images=patches,
149117
patch_mode=True,
150118
device=device,
151119
output_type="annotationstore",
152120
save_dir=track_tmp_path / "patch_output_annotationstore",
153121
)
122+
assert len(output_ann) == 3
123+
assert output_ann[0] == track_tmp_path / "patch_output_annotationstore" / "0.db"
154124

155-
assert output[0] == track_tmp_path / "patch_output_annotationstore" / "0.db"
156-
assert len(output) == 3
157-
store_ = SQLiteStore.open(output[0])
158-
annotations_ = store_.values()
159-
annotations_geometry_type = [
160-
str(annotation_.geometry_type) for annotation_ in annotations_
161-
]
162-
assert "Polygon" in annotations_geometry_type
163-
164-
annotations_list = list(annotations_)
165-
ann_properties = [ann.properties for ann in annotations_list]
166-
167-
result = {}
168-
for d in ann_properties:
169-
for key, value in d.items():
170-
result.setdefault(key, []).append(value)
171-
172-
polygons = [ann.geometry for ann in annotations_list]
173-
result["contour"] = [list(poly.exterior.coords) for poly in polygons]
174-
175-
assert all(
176-
np.array_equal(a, b)
177-
for a, b in zip(result["box"], output_["box"][0], strict=False)
178-
)
179-
180-
assert all(
181-
np.array_equal(a, b)
182-
for a, b in zip(result["centroid"], output_["centroid"][0], strict=False)
183-
)
184-
185-
assert all(
186-
np.array_equal(a, b)
187-
for a, b in zip(result["prob"], output_["prob"][0], strict=False)
188-
)
189-
190-
assert all(
191-
np.array_equal(a, b)
192-
for a, b in zip(result["type"], output_["type"][0], strict=False)
193-
)
194-
195-
assert all(
196-
np.array_equal(
197-
np.array(a[:-1], dtype=int), # discard last point
198-
np.array(b, dtype=int),
125+
for patch_idx, db_path in enumerate(output_ann):
126+
assert (
127+
db_path
128+
== track_tmp_path / "patch_output_annotationstore" / f"{patch_idx}.db"
199129
)
200-
for a, b in zip(result["contour"], output_["contour"][0], strict=False)
201-
)
130+
store_ = SQLiteStore.open(db_path)
131+
annotations_ = store_.values()
132+
annotations_geometry_type = [
133+
str(annotation_.geometry_type) for annotation_ in annotations_
134+
]
135+
annotations_list = list(annotations_)
136+
if expected_counts[patch_idx] > 0:
137+
assert "Polygon" in annotations_geometry_type
138+
139+
# Build result dict from annotation properties
140+
result = {}
141+
for ann in annotations_list:
142+
for key, value in ann.properties.items():
143+
result.setdefault(key, []).append(value)
144+
result["contour"] = [
145+
list(poly.exterior.coords)
146+
for poly in (a.geometry for a in annotations_list)
147+
]
148+
149+
# wrap it to make it compatible to assert_output_lengths
150+
result_ = {
151+
field: [result[field]]
152+
for field in ["box", "centroid", "contour", "prob", "type"]
153+
}
154+
155+
# Lengths and equality checks for this patch
156+
assert_output_lengths(result_, [expected_counts[patch_idx]])
157+
assert_output_equal(
158+
result_,
159+
output_dict,
160+
fields=["box", "centroid", "prob", "type"],
161+
indices_a=[0],
162+
indices_b=[patch_idx],
163+
)
164+
165+
# Contour check (discard last point)
166+
assert all(
167+
np.array_equal(np.array(a[:-1], dtype=int), np.array(b, dtype=int))
168+
for a, b in zip(
169+
result["contour"], output_dict["contour"][patch_idx], strict=False
170+
)
171+
)
172+
else:
173+
assert annotations_geometry_type == []
174+
assert annotations_list == []

0 commit comments

Comments
 (0)