
Commit f5b1885

update patch mode processing
1 parent 17e5422 commit f5b1885

5 files changed: +87 −47 lines changed

test.py

Lines changed: 27 additions & 7 deletions
@@ -1,26 +1,46 @@
-import pathlib
+from pathlib import Path
 
 from tiatoolbox.models.engine.nucleus_detector import (
     NucleusDetector,
 )
 from tiatoolbox.utils import env_detection as toolbox_env
+from tiatoolbox.wsicore.wsireader import WSIReader
 
-ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
+from tiatoolbox.utils.misc import imwrite
 
+ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
 
 if __name__ == "__main__":
-    # model_name = "sccnn-crchisto"
-    model_name = "mapde-conic"
+    model_name = "sccnn-crchisto"
+    # model_name = "mapde-conic"
+
+
+    # test_image_path = "/media/u1910100/data/slides/CMU-1-Small-Region.svs"
+    # reader = WSIReader.open(test_image_path)
+
+    # patch_1 = reader.read_region((1500, 1500), level=0, size=(31, 31))
+
+    # imwrite(Path("/media/u1910100/data/slides/patch_1.png"), patch_1)
+
+    # patch_2 = reader.read_region((1000, 1000), level=0, size=(31, 31))
+    # imwrite(Path("/media/u1910100/data/slides/patch_2.png"), patch_2)
+
+    # patches = [
+    #     Path("/media/u1910100/data/slides/patch_1.png"),
+    #     Path("/media/u1910100/data/slides/patch_2.png"),
+    # ]
+
 
     detector = NucleusDetector(model=model_name, batch_size=16, num_workers=8)
     detector.run(
-        images=[pathlib.Path("/media/u1910100/data/slides/CMU-1-Small-Region.svs")],
+        images=[Path("/media/u1910100/data/slides/wsi1_2k_2k.svs")],
+        # images=patches,
         patch_mode=False,
         device="cuda",
-        save_dir=pathlib.Path("/media/u1910100/data/overlays/test"),
+        save_dir=Path("/media/u1910100/data/overlays/test"),
         overwrite=True,
         output_type="annotationstore",
         class_dict={0: "nucleus"},
         auto_get_mask=True,
-        memory_threshold=80,
+        memory_threshold=70,
     )
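For reference, a minimal patch-mode sketch (not part of the commit; it assumes the commented-out code above has been run so the two PNG patches exist on disk, and that the run parameters shown are accepted in patch mode):

    patches = [
        Path("/media/u1910100/data/slides/patch_1.png"),
        Path("/media/u1910100/data/slides/patch_2.png"),
    ]
    detector.run(
        images=patches,
        patch_mode=True,
        device="cuda",
        save_dir=Path("/media/u1910100/data/overlays/test"),
        overwrite=True,
        output_type="annotationstore",
        class_dict={0: "nucleus"},
    )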

tiatoolbox/data/pretrained_model.yaml

Lines changed: 0 additions & 2 deletions
@@ -867,7 +867,6 @@ sccnn-crchisto:
         - { "units": "mpp", "resolution": 0.25 }
       output_resolutions:
         - { "units": "mpp", "resolution": 0.25 }
-      tile_shape: [ 2048, 2048 ]
       patch_input_shape: [ 31, 31 ]
       patch_output_shape: [ 13, 13 ]
       stride_shape: [ 8, 8 ]
@@ -891,7 +890,6 @@ sccnn-conic:
         - { "units": "mpp", "resolution": 0.25 }
       output_resolutions:
         - { "units": "mpp", "resolution": 0.25 }
-      tile_shape: [ 2048, 2048 ]
       patch_input_shape: [ 31, 31 ]
       patch_output_shape: [ 13, 13 ]
       stride_shape: [ 8, 8 ]

tiatoolbox/models/architecture/mapde.py

Lines changed: 27 additions & 17 deletions
@@ -239,34 +239,44 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor:
     def postproc(
         self: MapDe,
         block: np.ndarray,
-        block_info: dict,
-        depth_h: int,
-        depth_w: int,
+        block_info: dict | None = None,
+        depth_h: int = 0,
+        depth_w: int = 0,
     ) -> np.ndarray:
-        """Runs inside Dask.da.map_overlap on a padded NumPy block: (h_pad, w_pad, C).
+        """MapDe post-processing function.
 
-        Builds a processed mask per channel, runs peak_local_max then
-        writes 1.0 at centroid pixels.
-        Keeps only centroids whose (row,col) lie in the interior window:
+        Builds a processed mask per input channel, runs peak_local_max, then
+        writes 1.0 at peak pixels.
+
+        Can be called inside dask.array.map_overlap on a padded NumPy block (h_pad, w_pad, C)
+        to process large prediction maps in chunks. Keeps only centroids whose (row, col)
+        lie in the interior window:
             rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
-        Returns same spatial shape as input block: (h_pad, w_pad, C), float32.
+
+        Returns the same spatial shape as the input block.
 
         Args:
-            block: NumPy array (H, W, C) with padded block data.
-            block_info: Dask block info dict.
-            depth_h: Halo size in pixels for height (rows).
-            depth_w: Halo size in pixels for width (cols).
+            block: NumPy array (H, W, C).
+            block_info: Dask block info dict. Only used when called inside dask.array.map_overlap.
+            depth_h: Halo size in pixels for height (rows).
+                Only used when called inside dask.array.map_overlap.
+            depth_w: Halo size in pixels for width (cols).
+                Only used when called inside dask.array.map_overlap.
 
         Returns:
-            out: NumPy array (H, W, C) with 1 at centroids, 0 elsewhere.
+            out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
         """
         block_height, block_width, block_channels = block.shape
 
         # --- derive core (pre-overlap) size for THIS block ---
-        info = block_info[0]
-        locs = info["array-location"]  # a list of (start, stop) coordinates per axis
-        core_h = int(locs[0][1] - locs[0][0])  # r1 - r0
-        core_w = int(locs[1][1] - locs[1][0])
+        if block_info is None:
+            core_h = block_height - 2 * depth_h
+            core_w = block_width - 2 * depth_w
+        else:
+            info = block_info[0]
+            locs = info["array-location"]  # a list of (start, stop) coordinates per axis
+            core_h = int(locs[0][1] - locs[0][0])  # r1 - r0
+            core_w = int(locs[1][1] - locs[1][0])
 
         rmin, rmax = depth_h, depth_h + core_h
         cmin, cmax = depth_w, depth_w + core_w
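With the new defaults, postproc can now be called directly on a single patch prediction, outside Dask. A minimal sketch (not part of the commit; `model` is assumed to be a loaded MapDe instance and `pred` a hypothetical patch-level probability map):

    import numpy as np

    pred = np.random.rand(13, 13, 1).astype(np.float32)  # hypothetical patch prediction
    # block_info defaults to None and both depths to 0, so the whole block
    # is treated as the interior window.
    centroid_mask = model.postproc(pred)  # same shape as pred, 1.0 at detected peaks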

tiatoolbox/models/architecture/sccnn.py

Lines changed: 27 additions & 17 deletions
@@ -329,34 +329,44 @@ def spatially_constrained_layer1(
     def postproc(
         self: SCCNN,
         block: np.ndarray,
-        block_info: dict,
-        depth_h: int,
-        depth_w: int,
+        block_info: dict | None = None,
+        depth_h: int = 0,
+        depth_w: int = 0,
     ) -> np.ndarray:
-        """Runs inside Dask.da.map_overlap on a padded NumPy block: (h_pad, w_pad, C).
+        """SCCNN post-processing function.
 
-        Builds a processed mask per channel, runs peak_local_max then
-        writes 1.0 at centroid pixels.
-        Keeps only centroids whose (row,col) lie in the interior window:
+        Builds a processed mask per input channel, runs peak_local_max, then
+        writes 1.0 at peak pixels.
+
+        Can be called inside dask.array.map_overlap on a padded NumPy block (h_pad, w_pad, C)
+        to process large prediction maps in chunks. Keeps only centroids whose (row, col)
+        lie in the interior window:
            rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
-        Returns same spatial shape as input block: (h_pad, w_pad, C), float32.
+
+        Returns the same spatial shape as the input block.
 
         Args:
-            block: NumPy array (H, W, C) with padded block data.
-            block_info: Dask block info dict.
-            depth_h: Halo size in pixels for height (rows).
-            depth_w: Halo size in pixels for width (cols).
+            block: NumPy array (H, W, C).
+            block_info: Dask block info dict. Only used when called inside dask.array.map_overlap.
+            depth_h: Halo size in pixels for height (rows).
+                Only used when called inside dask.array.map_overlap.
+            depth_w: Halo size in pixels for width (cols).
+                Only used when called inside dask.array.map_overlap.
 
         Returns:
-            out: NumPy array (H, W, C) with 1 at centroids, 0 elsewhere.
+            out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
        """
         block_height, block_width, block_channels = block.shape
 
         # --- derive core (pre-overlap) size for THIS block ---
-        info = block_info[0]
-        locs = info["array-location"]  # a list of (start, stop) coordinates per axis
-        core_h = int(locs[0][1] - locs[0][0])  # r1 - r0
-        core_w = int(locs[1][1] - locs[1][0])
+        if block_info is None:
+            core_h = block_height - 2 * depth_h
+            core_w = block_width - 2 * depth_w
+        else:
+            info = block_info[0]
+            locs = info["array-location"]  # a list of (start, stop) coordinates per axis
+            core_h = int(locs[0][1] - locs[0][0])  # r1 - r0
+            core_w = int(locs[1][1] - locs[1][0])
 
         rmin, rmax = depth_h, depth_h + core_h
         cmin, cmax = depth_w, depth_w + core_w
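For orientation, a small self-contained demo (not from the commit) of the block metadata Dask injects into functions that declare a block_info parameter. The engine invokes postproc through dask.array.map_overlap, but the injection mechanism shown here with map_blocks is the same:

    import dask.array as da
    import numpy as np

    def show_core_size(block, block_info=None):
        # block_info[0] describes the first input array; "array-location" holds
        # per-axis (start, stop) coordinates of the current block within it.
        locs = block_info[0]["array-location"]
        core_h = int(locs[0][1] - locs[0][0])
        core_w = int(locs[1][1] - locs[1][0])
        print(f"block {block.shape} -> core ({core_h}, {core_w})")
        return block

    arr = da.zeros((256, 256, 1), chunks=(128, 128, 1), dtype=np.float32)
    arr.map_blocks(show_core_size, dtype=np.float32).compute()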

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 6 additions & 4 deletions
@@ -96,11 +96,11 @@ def __init__(
 
     def post_process_patches(
         self: NucleusDetector,
-        raw_predictions: da.Array,
+        raw_predictions: list[da.Array],
         prediction_shape: tuple[int, ...],
         prediction_dtype: type,
         **kwargs: Unpack[SemanticSegmentorRunParams],
-    ) -> list[pd.DataFrame]:
+    ) -> list[np.ndarray]:
         """Define how to post-process patch predictions.
 
         Args:
@@ -117,7 +117,7 @@ def post_process_patches(
         _ = prediction_dtype
 
         batch_predictions = []
-        for i in range(raw_predictions.shape[0]):
+        for i in range(len(raw_predictions)):
             batch_predictions.append(self.model.postproc_func(raw_predictions[i]))
         return batch_predictions
 
@@ -233,13 +233,15 @@ def save_predictions(
         if self.patch_mode:
             save_paths = []
             for i, predictions in enumerate(processed_predictions["predictions"]):
+                predictions_da = da.from_array(predictions, chunks=predictions.shape)
+
                 if isinstance(self.images[i], Path):
                     output_path = save_path.parent / (self.images[i].stem + ".db")
                 else:
                     output_path = save_path.parent / (str(i) + ".db")
 
                 out_file = self.write_centroids_to_store(
-                    predictions,
+                    predictions_da,
                     scale_factor=scale_factor,
                     class_dict=class_dict,
                     save_path=output_path,
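A minimal illustration (not from the commit) of the single-chunk wrapping added above: passing chunks=predictions.shape turns the whole NumPy prediction into one Dask chunk, presumably so write_centroids_to_store can keep a single dask-array code path in patch mode:

    import dask.array as da
    import numpy as np

    pred = np.zeros((13, 13, 1), dtype=np.float32)  # hypothetical patch prediction
    pred_da = da.from_array(pred, chunks=pred.shape)
    assert pred_da.numblocks == (1, 1, 1)  # the whole array lives in a single chunk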
