Skip to content

Commit d2eae54

Browse files
committed
update model postproc function
1 parent b58dd88 commit d2eae54

File tree

5 files changed

+138
-17
lines changed

5 files changed

+138
-17
lines changed

tests/models/test_arch_mapde.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,56 @@ def test_functionality(remote_sample: Callable) -> None:
8383
Path(weights_path).unlink()
8484

8585

86+
def test_postproc_params_override(remote_sample: Callable) -> None:
87+
"""Test MapDe post-processing with overridden parameters."""
88+
sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
89+
reader = WSIReader.open(sample_wsi)
90+
91+
# * test fast mode (architecture used in PanNuke paper)
92+
patch = reader.read_bounds(
93+
(0, 0, 252, 252),
94+
resolution=0.50,
95+
units="mpp",
96+
coord_space="resolution",
97+
)
98+
99+
model, weight_path = _load_mapde(name="mapde-conic")
100+
patch = model.preproc(patch)
101+
batch = torch.from_numpy(patch)[None]
102+
raw_output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
103+
104+
output_normal = model.postproc(raw_output[0])
105+
(
106+
ys_normal,
107+
xs_normal,
108+
_,
109+
) = np.nonzero(output_normal)
110+
111+
# Use higher threshold should result in less detections
112+
output_high_threshold = model.postproc(raw_output[0], threshold_abs=500)
113+
(
114+
ys_high_threshold,
115+
xs_high_threshold,
116+
_,
117+
) = np.nonzero(output_high_threshold)
118+
119+
# Use bigger min_distance should result in less detections
120+
output_large_min_distance = model.postproc(raw_output[0], min_distance=9)
121+
(
122+
ys_large_min_distance,
123+
xs_large_min_distance,
124+
_,
125+
) = np.nonzero(output_large_min_distance)
126+
127+
assert len(xs_high_threshold) < len(xs_normal)
128+
assert len(ys_high_threshold) < len(ys_normal)
129+
130+
assert len(xs_large_min_distance) < len(xs_normal)
131+
assert len(ys_large_min_distance) < len(ys_normal)
132+
133+
Path(weight_path).unlink()
134+
135+
86136
def test_multiclass_output() -> None:
87137
"""Test the architecture for multi-class output."""
88138
multiclass_model = MapDe(num_input_channels=3, num_classes=3)

tests/models/test_arch_sccnn.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,36 @@ def test_functionality(remote_sample: Callable) -> None:
8787
ys, xs, _ = np.nonzero(output)
8888
np.testing.assert_array_equal(xs, np.array([]))
8989
np.testing.assert_array_equal(ys, np.array([]))
90+
91+
92+
def test_postproc_params_override(remote_sample: Callable) -> None:
93+
"""Test postproc parameters override."""
94+
sample_wsi = str(remote_sample("wsi1_2k_2k_svs"))
95+
reader = WSIReader.open(sample_wsi)
96+
97+
# * test fast mode (architecture used in PanNuke paper)
98+
patch = reader.read_bounds(
99+
(30, 30, 61, 61),
100+
resolution=0.25,
101+
units="mpp",
102+
coord_space="resolution",
103+
)
104+
model = _load_sccnn(name="sccnn-crchisto")
105+
patch = model.preproc(patch)
106+
batch = torch.from_numpy(patch)[None]
107+
raw_output = model.infer_batch(
108+
model,
109+
batch,
110+
device=select_device(on_gpu=env_detection.has_gpu()),
111+
)
112+
# Override to a high threshold to get no detections
113+
output = model.postproc(raw_output[0], threshold_abs=0.9)
114+
ys, xs, _ = np.nonzero(output)
115+
np.testing.assert_array_equal(xs, np.array([]))
116+
np.testing.assert_array_equal(ys, np.array([]))
117+
118+
# Override with small min_distance
119+
output = model.postproc(raw_output[0], min_distance=1)
120+
ys, xs, _ = np.nonzero(output)
121+
np.testing.assert_array_equal(xs, np.array([8]))
122+
np.testing.assert_array_equal(ys, np.array([7]))

tiatoolbox/models/architecture/mapde.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor:
240240
def postproc(
241241
self: MapDe,
242242
block: np.ndarray,
243+
min_distance: int | None = None,
244+
threshold_abs: float | None = None,
245+
threshold_rel: float | None = None,
243246
block_info: dict | None = None,
244247
depth_h: int = 0,
245248
depth_w: int = 0,
@@ -253,22 +256,37 @@ def postproc(
253256
254257
Args:
255258
block (np.ndarray):
256-
NumPy array (H, W, C).
257-
block_info (dict):
259+
shape (H, W, C).
260+
min_distance (int | None):
261+
The minimal allowed distance separating peaks.
262+
threshold_abs (float | None):
263+
Minimum intensity of peaks.
264+
threshold_rel (float | None):
265+
Minimum intensity of peaks.
266+
block_info (dict | None):
258267
Dask block info dict. Only used when called from
259268
dask.array.map_overlap.
260-
depth_h: Halo size in pixels for height (rows).
261-
Only used when it's called from dask.array.map_overlap.
262-
depth_w: Halo size in pixels for width (cols).
263-
Only used when it's called from dask.array.map_overlap.
269+
depth_h (int):
270+
Halo size in pixels for height (rows). Only used
271+
when it's called from dask.array.map_overlap.
272+
depth_w (int):
273+
Halo size in pixels for width (cols). Only used
274+
when it's called from dask.array.map_overlap.
264275
265276
Returns:
266277
out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
267278
"""
279+
min_distance_to_use = (
280+
self.min_distance if min_distance is None else min_distance
281+
)
282+
threshold_abs_to_use = (
283+
self.threshold_abs if threshold_abs is None else threshold_abs
284+
)
268285
return peak_detection_map_overlap(
269286
block,
270-
min_distance=self.min_distance,
271-
threshold_abs=self.threshold_abs,
287+
min_distance=min_distance_to_use,
288+
threshold_abs=threshold_abs_to_use,
289+
threshold_rel=threshold_rel,
272290
block_info=block_info,
273291
depth_h=depth_h,
274292
depth_w=depth_w,

tiatoolbox/models/architecture/sccnn.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,9 @@ def spatially_constrained_layer1(
335335
def postproc(
336336
self: SCCNN,
337337
block: np.ndarray,
338+
min_distance: int | None = None,
339+
threshold_abs: float | None = None,
340+
threshold_rel: float | None = None,
338341
block_info: dict | None = None,
339342
depth_h: int = 0,
340343
depth_w: int = 0,
@@ -347,21 +350,38 @@ def postproc(
347350
Returns same spatial shape as the input block
348351
349352
Args:
350-
block: NumPy array (H, W, C).
351-
block_info: Dask block info dict. Only used when called inside
353+
block (np.ndarray):
354+
shape (H, W, C).
355+
min_distance (int | None):
356+
The minimal allowed distance separating peaks.
357+
threshold_abs (float | None):
358+
Minimum intensity of peaks.
359+
threshold_rel (float | None):
360+
Minimum intensity of peaks.
361+
block_info (dict | None):
362+
Dask block info dict. Only used when called from
352363
dask.array.map_overlap.
353-
depth_h: Halo size in pixels for height (rows).
354-
Only used when it's called inside dask.array.map_overlap.
355-
depth_w: Halo size in pixels for width (cols).
356-
Only used when it's called inside dask.array.map_overlap.
364+
depth_h (int):
365+
Halo size in pixels for height (rows). Only used
366+
when it's called from dask.array.map_overlap.
367+
depth_w (int):
368+
Halo size in pixels for width (cols). Only used
369+
when it's called from dask.array.map_overlap.
357370
358371
Returns:
359372
out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
360373
"""
374+
min_distance_to_use = (
375+
self.min_distance if min_distance is None else min_distance
376+
)
377+
threshold_abs_to_use = (
378+
self.threshold_abs if threshold_abs is None else threshold_abs
379+
)
361380
return peak_detection_map_overlap(
362381
block,
363-
min_distance=self.min_distance,
364-
threshold_abs=self.threshold_abs,
382+
min_distance=min_distance_to_use,
383+
threshold_abs=threshold_abs_to_use,
384+
threshold_rel=threshold_rel,
365385
block_info=block_info,
366386
depth_h=depth_h,
367387
depth_w=depth_w,

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def save_detection_arrays_to_store(
679679
classes = np.atleast_1d(np.asarray(classes))
680680
probs = np.atleast_1d(np.asarray(probs))
681681

682-
if not (len(xs) == len(ys) == len(classes) == len(probs)):
682+
if not len(xs) == len(ys) == len(classes) == len(probs):
683683
msg = "Detection record lengths are misaligned."
684684
raise ValueError(msg)
685685

0 commit comments

Comments
 (0)