Skip to content

Commit d63a7cc

Browse files
committed
refactor code and improve typing
1 parent 884bdf0 commit d63a7cc

File tree

6 files changed

+140
-102
lines changed

6 files changed

+140
-102
lines changed

tests/engines/test_nucleus_detection_engine.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def test_nucleus_detector_patch(
148148
wsi_reader = WSIReader.open(mini_wsi_svs)
149149
patch_1 = wsi_reader.read_rect((0, 0), (252, 252), resolution=0.5, units="mpp")
150150
patch_2 = wsi_reader.read_rect((252, 252), (252, 252), resolution=0.5, units="mpp")
151+
patch_3 = np.zeros((252, 252, 3), dtype=np.uint8)
151152

152153
pretrained_model = "mapde-conic"
153154

@@ -159,7 +160,7 @@ def test_nucleus_detector_patch(
159160
device=device,
160161
output_type="annotationstore",
161162
memory_threshold=50,
162-
images=[patch_1, patch_2],
163+
images=[patch_1, patch_2, patch_3],
163164
save_dir=save_dir,
164165
overwrite=True,
165166
class_dict=None,
@@ -173,6 +174,10 @@ def test_nucleus_detector_patch(
173174
assert len(store_2.values()) == 52
174175
store_2.close()
175176

177+
store_3 = SQLiteStore.open(save_dir / "2.db")
178+
assert len(store_3.values()) == 0
179+
store_3.close()
180+
176181
imwrite(save_dir / "patch_0.png", patch_1)
177182
imwrite(save_dir / "patch_1.png", patch_2)
178183
_ = nucleus_detector.run(

tiatoolbox/data/pretrained_model.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ mapde-crchisto:
815815
threshold_abs: 250
816816
num_classes: 1
817817
postproc_tile_shape: [ 2048, 2048 ]
818-
output_class_dict: {
818+
class_dict: {
819819
0: "nucleus"
820820
}
821821
ioconfig:
@@ -840,7 +840,7 @@ mapde-conic:
840840
threshold_abs: 205
841841
num_classes: 1
842842
postproc_tile_shape: [ 2048, 2048 ]
843-
output_class_dict: {
843+
class_dict: {
844844
0: "nucleus"
845845
}
846846
ioconfig:
@@ -866,7 +866,7 @@ sccnn-crchisto:
866866
threshold_abs: 0.20
867867
patch_output_shape: [ 13, 13 ]
868868
postproc_tile_shape: [ 2048, 2048 ]
869-
output_class_dict: {
869+
class_dict: {
870870
0: "nucleus"
871871
}
872872
ioconfig:
@@ -892,7 +892,7 @@ sccnn-conic:
892892
threshold_abs: 0.05
893893
patch_output_shape: [ 13, 13 ]
894894
postproc_tile_shape: [ 2048, 2048 ]
895-
output_class_dict: {
895+
class_dict: {
896896
0: "nucleus"
897897
}
898898
ioconfig:

tiatoolbox/models/architecture/mapde.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
import numpy as np
1212
import torch
1313
import torch.nn.functional as F # noqa: N812
14-
from skimage.feature import peak_local_max
1514

1615
from tiatoolbox.models.architecture.micronet import MicroNet
16+
from tiatoolbox.models.architecture.utils import peak_detection_da_map_overlap
1717

1818

1919
class MapDe(MicroNet):
@@ -79,15 +79,15 @@ def __init__(
7979
threshold_abs: float = 250,
8080
num_classes: int = 1,
8181
postproc_tile_shape: tuple[int, int] = (2048, 2048),
82-
output_class_dict: dict[int, str] | None = None,
82+
class_dict: dict[int, str] | None = None,
8383
) -> None:
8484
"""Initialize :class:`MapDe`."""
8585
super().__init__(
8686
num_output_channels=num_classes * 2,
8787
num_input_channels=num_input_channels,
8888
out_activation="relu",
8989
)
90-
self.output_class_dict = output_class_dict
90+
self.output_class_dict = class_dict
9191
self.postproc_tile_shape = postproc_tile_shape
9292

9393
dist_filter = np.array(
@@ -249,11 +249,6 @@ def postproc(
249249
Builds a processed mask per input channel, runs peak_local_max then
250250
writes 1.0 at peak pixels.
251251
252-
Can be called inside Dask.da.map_overlap on a padded NumPy block:
253-
(h_pad, w_pad, C) to process large prediction maps in chunks.
254-
Keeps only centroids whose (row,col) lie in the interior window:
255-
rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
256-
257252
Returns same spatial shape as the input block
258253
259254
Args:
@@ -268,40 +263,14 @@ def postproc(
268263
Returns:
269264
out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
270265
"""
271-
block_height, block_width, block_channels = block.shape
272-
273-
# --- derive core (pre-overlap) size for THIS block ---
274-
if block_info is None:
275-
core_h = block_height - 2 * depth_h
276-
core_w = block_width - 2 * depth_w
277-
else:
278-
info = block_info[0]
279-
locs = info[
280-
"array-location"
281-
] # a list of (start, stop) coordinates per axis
282-
core_h = int(locs[0][1] - locs[0][0]) # r1 - r0
283-
core_w = int(locs[1][1] - locs[1][0])
284-
285-
rmin, rmax = depth_h, depth_h + core_h
286-
cmin, cmax = depth_w, depth_w + core_w
287-
288-
out = np.zeros((block_height, block_width, block_channels), dtype=np.float32)
289-
290-
for ch in range(block_channels):
291-
img = np.asarray(block[..., ch]) # NumPy 2D view
292-
293-
coords = peak_local_max(
294-
img,
295-
min_distance=self.min_distance,
296-
threshold_abs=self.threshold_abs,
297-
exclude_border=False,
298-
)
299-
300-
for r, c in coords:
301-
if (rmin <= r < rmax) and (cmin <= c < cmax):
302-
out[r, c, ch] = 1.0
303-
304-
return out
266+
return peak_detection_da_map_overlap(
267+
block,
268+
min_distance=self.min_distance,
269+
threshold_abs=self.threshold_abs,
270+
block_info=block_info,
271+
depth_h=depth_h,
272+
depth_w=depth_w,
273+
)
305274

306275
@staticmethod
307276
def infer_batch(

tiatoolbox/models/architecture/sccnn.py

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
from __future__ import annotations
1111

1212
from collections import OrderedDict
13+
from typing import TYPE_CHECKING
14+
15+
if TYPE_CHECKING:
16+
import numpy as np
1317

14-
import numpy as np
1518
import torch
16-
from skimage.feature import peak_local_max
1719
from torch import nn
1820

21+
from tiatoolbox.models.architecture.utils import peak_detection_da_map_overlap
1922
from tiatoolbox.models.models_abc import ModelABC
2023

2124

@@ -92,7 +95,7 @@ def __init__(
9295
min_distance: int = 6,
9396
threshold_abs: float = 0.20,
9497
postproc_tile_shape: tuple[int, int] = (2048, 2048),
95-
output_class_dict: dict[int, str] | None = None,
98+
class_dict: dict[int, str] | None = None,
9699
) -> None:
97100
"""Initialize :class:`SCCNN`."""
98101
super().__init__()
@@ -102,7 +105,7 @@ def __init__(
102105
self.out_height = out_height
103106
self.out_width = out_width
104107
self.postproc_tile_shape = postproc_tile_shape
105-
self.output_class_dict = output_class_dict
108+
self.output_class_dict = class_dict
106109

107110
# Create mesh grid and convert to 3D vector
108111
x, y = torch.meshgrid(
@@ -341,11 +344,6 @@ def postproc(
341344
Builds a processed mask per input channel, runs peak_local_max then
342345
writes 1.0 at peak pixels.
343346
344-
Can be called inside Dask.da.map_overlap on a padded NumPy block:
345-
(h_pad, w_pad, C) to process large prediction maps in chunks.
346-
Keeps only centroids whose (row,col) lie in the interior window:
347-
rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
348-
349347
Returns same spatial shape as the input block
350348
351349
Args:
@@ -360,40 +358,14 @@ def postproc(
360358
Returns:
361359
out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
362360
"""
363-
block_height, block_width, block_channels = block.shape
364-
365-
# --- derive core (pre-overlap) size for THIS block ---
366-
if block_info is None:
367-
core_h = block_height - 2 * depth_h
368-
core_w = block_width - 2 * depth_w
369-
else:
370-
info = block_info[0]
371-
locs = info[
372-
"array-location"
373-
] # a list of (start, stop) coordinates per axis
374-
core_h = int(locs[0][1] - locs[0][0]) # r1 - r0
375-
core_w = int(locs[1][1] - locs[1][0])
376-
377-
rmin, rmax = depth_h, depth_h + core_h
378-
cmin, cmax = depth_w, depth_w + core_w
379-
380-
out = np.zeros((block_height, block_width, block_channels), dtype=np.float32)
381-
382-
for ch in range(block_channels):
383-
img = np.asarray(block[..., ch]) # NumPy 2D view
384-
385-
coords = peak_local_max(
386-
img,
387-
min_distance=self.min_distance,
388-
threshold_abs=self.threshold_abs,
389-
exclude_border=False,
390-
)
391-
392-
for r, c in coords:
393-
if (rmin <= r < rmax) and (cmin <= c < cmax):
394-
out[r, c, ch] = 1.0
395-
396-
return out
361+
return peak_detection_da_map_overlap(
362+
block,
363+
min_distance=self.min_distance,
364+
threshold_abs=self.threshold_abs,
365+
block_info=block_info,
366+
depth_h=depth_h,
367+
depth_w=depth_w,
368+
)
397369

398370
@staticmethod
399371
def infer_batch(

tiatoolbox/models/architecture/utils.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import torch
10+
from skimage.feature import peak_local_max
1011
from torch import nn
1112

1213
from tiatoolbox import logger
@@ -251,3 +252,74 @@ def argmax_last_axis(image: np.ndarray) -> np.ndarray:
251252
252253
"""
253254
return image.argmax(axis=-1)
255+
256+
257+
def peak_detection_da_map_overlap(
258+
block: np.ndarray,
259+
min_distance: int,
260+
threshold_abs: float | None = None,
261+
threshold_rel: float | None = None,
262+
block_info: dict | None = None,
263+
depth_h: int = 0,
264+
depth_w: int = 0,
265+
) -> np.ndarray:
266+
"""Post-processing function for peak detection.
267+
268+
Builds a processed mask per input channel. Runs peak_local_max then
269+
writes 1.0 at peak pixels.
270+
271+
Can be called from Dask.da.map_overlap on a padded NumPy block
272+
(h_pad, w_pad, C) to process large prediction maps in chunks with overlap.
273+
Keeps only centroids whose (row,col) lie in the interior window:
274+
rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
275+
276+
Returns same spatial shape as the input block
277+
278+
Args:
279+
block: NumPy array (H, W, C).
280+
min_distance: Minimum number of pixels separating peaks.
281+
threshold_abs: Minimum intensity of peaks. By default, None.
282+
threshold_rel: Minimum relative intensity of peaks. By default, None.
283+
block_info: Dask block info dict.
284+
Only used when called from dask.array.map_overlap.
285+
depth_h: Halo size in pixels for height (rows).
286+
Only used when called from dask.array.map_overlap.
287+
depth_w: Halo size in pixels for width (cols).
288+
Only used when it's called from dask.array.map_overlap.
289+
290+
Returns:
291+
out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
292+
"""
293+
block_height, block_width, block_channels = block.shape
294+
295+
# --- derive core (pre-overlap) size for THIS block ---
296+
if block_info is None:
297+
core_h = block_height - 2 * depth_h
298+
core_w = block_width - 2 * depth_w
299+
else:
300+
info = block_info[0]
301+
locs = info["array-location"] # a list of (start, stop) coordinates per axis
302+
core_h = int(locs[0][1] - locs[0][0]) # r1 - r0
303+
core_w = int(locs[1][1] - locs[1][0])
304+
305+
rmin, rmax = depth_h, depth_h + core_h
306+
cmin, cmax = depth_w, depth_w + core_w
307+
308+
out = np.zeros((block_height, block_width, block_channels), dtype=np.float32)
309+
310+
for ch in range(block_channels):
311+
img = np.asarray(block[..., ch]) # NumPy 2D view
312+
313+
coords = peak_local_max(
314+
img,
315+
min_distance=min_distance,
316+
threshold_abs=threshold_abs,
317+
threshold_rel=threshold_rel,
318+
exclude_border=False,
319+
)
320+
321+
for r, c in coords:
322+
if (rmin <= r < rmax) and (cmin <= c < cmax):
323+
out[r, c, ch] = 1.0
324+
325+
return out

0 commit comments

Comments
 (0)