Skip to content

Commit cf70c97

Browse files
authored
🔧 Add Zarr-based Caching and Memory-Aware Patch Merging for WSI Inference (#949)
This PR introduces Zarr-based caching and memory-efficient patch merging for semantic segmentation workflows. The changes aim to handle large whole-slide image (WSI) inference by implementing dynamic memory management and disk-based caching when system resources are constrained. Key changes: - Adds Zarr integration for intermediate canvas and count array storage during WSI inference - Implements memory threshold monitoring using psutil to trigger disk spilling when RAM usage exceeds limits - Refactors patch merging logic into modular helper functions for better maintainability
1 parent 708e706 commit cf70c97

File tree

4 files changed

+774
-225
lines changed

4 files changed

+774
-225
lines changed

‎tests/engines/test_semantic_segmentor.py‎

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import sqlite3
77
from pathlib import Path
8-
from typing import Callable
8+
from typing import TYPE_CHECKING, Callable
99

1010
import numpy as np
1111
import torch
@@ -16,6 +16,9 @@
1616
from tiatoolbox.utils import env_detection as toolbox_env
1717
from tiatoolbox.utils.misc import imread
1818

19+
if TYPE_CHECKING:
20+
import pytest
21+
1922
device = "cuda" if toolbox_env.has_gpu() else "cpu"
2023

2124

@@ -212,14 +215,16 @@ def test_wsi_segmentor_zarr(
212215
remote_sample: Callable,
213216
sample_svs: Path,
214217
tmp_path: Path,
218+
caplog: pytest.LogCaptureFixture,
215219
) -> None:
216220
"""Test SemanticSegmentor for WSIs with zarr output."""
217221
wsi1_2k_2k_svs = Path(remote_sample("wsi1_2k_2k_svs"))
218222

219223
segmentor = SemanticSegmentor(
220224
model="fcn-tissue_mask",
221-
batch_size=32,
225+
batch_size=64,
222226
verbose=False,
227+
num_loader_workers=1,
223228
)
224229
# Return Probabilities is False
225230
output = segmentor.run(
@@ -229,15 +234,26 @@ def test_wsi_segmentor_zarr(
229234
device=device,
230235
patch_mode=False,
231236
save_dir=tmp_path / "wsi_out_check",
237+
batch_size=2,
232238
output_type="zarr",
239+
memory_threshold=1,
233240
)
234241

235242
output_ = zarr.open(output[sample_svs], mode="r")
236243
assert 0.17 < np.mean(output_["predictions"][:]) < 0.19
237244
assert "probabilities" not in output_
245+
assert "canvas" not in output_
246+
assert "count" not in output_
247+
assert "Current Memory usage:" in caplog.text
238248

239249
# Return Probabilities is True
240250
# Using small image for faster run
251+
segmentor = SemanticSegmentor(
252+
model="fcn-tissue_mask",
253+
batch_size=32,
254+
verbose=False,
255+
num_loader_workers=1,
256+
)
241257
segmentor.drop_keys = []
242258
output = segmentor.run(
243259
images=[sample_svs, wsi1_2k_2k_svs],

‎tiatoolbox/models/dataset/dataset_abc.py‎

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -441,20 +441,7 @@ def __init__( # skipcq: PY-R1000
441441
patch_input_shape = np.array(patch_input_shape)
442442
stride_shape = np.array(stride_shape)
443443

444-
if (
445-
not np.issubdtype(patch_input_shape.dtype, np.integer)
446-
or np.size(patch_input_shape) > 2 # noqa: PLR2004
447-
or np.any(patch_input_shape < 0)
448-
):
449-
msg = f"Invalid `patch_input_shape` value {patch_input_shape}."
450-
raise ValueError(msg)
451-
if (
452-
not np.issubdtype(stride_shape.dtype, np.integer)
453-
or np.size(stride_shape) > 2 # noqa: PLR2004
454-
or np.any(stride_shape < 0)
455-
):
456-
msg = f"Invalid `stride_shape` value {stride_shape}."
457-
raise ValueError(msg)
444+
_validate_patch_stride_shape(patch_input_shape, stride_shape)
458445

459446
self.preproc_func = preproc_func
460447
img_path = Path(img_path)
@@ -475,6 +462,7 @@ def __init__( # skipcq: PY-R1000
475462
stride_shape=stride_shape[::-1],
476463
patch_output_shape=patch_output_shape,
477464
)
465+
self.full_outputs = self.outputs
478466
else:
479467
self.inputs = PatchExtractor.get_coordinates(
480468
image_shape=wsi_shape,
@@ -510,6 +498,7 @@ def __init__( # skipcq: PY-R1000
510498
)
511499
self.inputs = self.inputs[selected]
512500
if hasattr(self, "outputs"):
501+
self.full_outputs = self.outputs # Full list of outputs
513502
self.outputs = self.outputs[selected]
514503

515504
if len(self.inputs) == 0:
@@ -639,3 +628,40 @@ def __getitem__(self: PatchDataset, idx: int) -> dict:
639628
return data
640629

641630
return data
631+
632+
633+
def _validate_patch_stride_shape(
634+
patch_input_shape: np.ndarray, stride_shape: np.ndarray
635+
) -> None:
636+
"""Validate patch and stride shape inputs for semantic segmentation.
637+
638+
Checks that both `patch_input_shape` and `stride_shape` are integer arrays of
639+
length ≤ 2 and contain non-negative values. Raises a ValueError if any
640+
condition fails.
641+
642+
Parameters:
643+
patch_input_shape (np.ndarray):
644+
Shape of the input patch (e.g., height, width).
645+
stride_shape (np.ndarray):
646+
Stride dimensions used for patch extraction.
647+
648+
Raises:
649+
ValueError:
650+
If either input is not a valid integer array of appropriate
651+
shape and values.
652+
653+
"""
654+
if (
655+
not np.issubdtype(patch_input_shape.dtype, np.integer)
656+
or np.size(patch_input_shape) > 2 # noqa: PLR2004
657+
or np.any(patch_input_shape < 0)
658+
):
659+
msg = f"Invalid `patch_input_shape` value {patch_input_shape}."
660+
raise ValueError(msg)
661+
if (
662+
not np.issubdtype(stride_shape.dtype, np.integer)
663+
or np.size(stride_shape) > 2 # noqa: PLR2004
664+
or np.any(stride_shape < 0)
665+
):
666+
msg = f"Invalid `stride_shape` value {stride_shape}."
667+
raise ValueError(msg)

‎tiatoolbox/models/engine/engine_abc.py‎

Lines changed: 76 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import dask.array as da
4444
import numpy as np
4545
import torch
46+
import zarr
4647
from dask import compute
4748
from dask.diagnostics import ProgressBar
4849
from torch import nn
@@ -71,58 +72,6 @@
7172
from tiatoolbox.wsicore.wsireader import WSIReader
7273

7374

74-
def prepare_engines_save_dir(
75-
save_dir: str | Path | None,
76-
*,
77-
patch_mode: bool,
78-
overwrite: bool = False,
79-
) -> Path | None:
80-
"""Create or validate the save directory for engine outputs.
81-
82-
Args:
83-
save_dir (str | Path | None):
84-
Path to the output directory.
85-
patch_mode (bool):
86-
Whether the input is treated as patches.
87-
overwrite (bool):
88-
Whether to overwrite existing directory. Default is False.
89-
90-
Returns:
91-
Path | None:
92-
Path to the output directory if created or validated, else None.
93-
94-
Raises:
95-
OSError:
96-
If patch_mode is False and save_dir is not provided.
97-
98-
"""
99-
if patch_mode:
100-
if save_dir is not None:
101-
save_dir = Path(save_dir)
102-
save_dir.mkdir(parents=True, exist_ok=overwrite)
103-
return save_dir
104-
return None
105-
106-
if save_dir is None:
107-
msg = (
108-
"Input WSIs detected but no save directory provided. "
109-
"Please provide a 'save_dir'."
110-
)
111-
raise OSError(msg)
112-
113-
logger.info(
114-
"When providing multiple whole slide images, "
115-
"the outputs will be saved and the locations of outputs "
116-
"will be returned to the calling function when `run()` "
117-
"finishes successfully."
118-
)
119-
120-
save_dir = Path(save_dir)
121-
save_dir.mkdir(parents=True, exist_ok=overwrite)
122-
123-
return save_dir
124-
125-
12675
class EngineABCRunParams(TypedDict, total=False):
12776
"""Parameters for configuring the :func:`EngineABC.run()` method.
12877
@@ -180,6 +129,9 @@ class EngineABCRunParams(TypedDict, total=False):
180129
return_labels: bool
181130
scale_factor: tuple[float, float]
182131
stride_shape: IntPair
132+
memory_threshold: int
133+
da_length_threshold: int
134+
auto_get_mask: bool
183135
verbose: bool
184136

185137

@@ -432,6 +384,7 @@ def get_dataloader(
432384
ioconfig: ModelIOConfigABC | None = None,
433385
*,
434386
patch_mode: bool = True,
387+
auto_get_mask: bool = True,
435388
) -> torch.utils.data.DataLoader:
436389
"""Pre-process images and masks and return a DataLoader for inference.
437390
@@ -450,6 +403,12 @@ def get_dataloader(
450403
IO configuration object specifying patch size, stride, and resolution.
451404
patch_mode (bool):
452405
Whether to treat input as patches (`True`) or WSIs (`False`).
406+
auto_get_mask (bool):
407+
Auto generates tissue mask using `wsireader.tissue_mask()` when
408+
patch_mode is False.
409+
If set to `True`, this mask processes only the tissue regions in the
410+
image. If `False` all the patches in the image are processed.
411+
Default is `True`.
453412
454413
Returns:
455414
torch.utils.data.DataLoader:
@@ -468,6 +427,7 @@ def get_dataloader(
468427
stride_shape=ioconfig.stride_shape,
469428
resolution=ioconfig.input_resolutions[0]["resolution"],
470429
units=ioconfig.input_resolutions[0]["units"],
430+
auto_get_mask=auto_get_mask,
471431
)
472432

473433
dataset.preproc_func = self.model.preproc_func
@@ -692,6 +652,11 @@ def save_predictions(
692652
with ProgressBar():
693653
compute(*write_tasks)
694654

655+
zarr_group = zarr.open(save_path, mode="r+")
656+
for key in self.drop_keys:
657+
if key in zarr_group:
658+
del zarr_group[key]
659+
695660
return save_path
696661

697662
values_to_compute = [processed_predictions[k] for k in keys_to_compute]
@@ -726,6 +691,7 @@ def save_predictions(
726691
def infer_wsi(
727692
self: EngineABC,
728693
dataloader: DataLoader,
694+
save_path: Path,
729695
**kwargs: Unpack[EngineABCRunParams],
730696
) -> dict:
731697
"""Run model inference on a whole slide image (WSI).
@@ -737,6 +703,9 @@ def infer_wsi(
737703
Args:
738704
dataloader (DataLoader):
739705
PyTorch DataLoader configured for WSI processing.
706+
save_path (Path):
707+
Path to save the intermediate output. The intermediate output is saved
708+
in a zarr file.
740709
**kwargs (EngineABCRunParams):
741710
Additional runtime parameters used during inference.
742711
@@ -746,6 +715,7 @@ def infer_wsi(
746715
747716
"""
748717
_ = kwargs.get("patch_mode", False)
718+
_ = save_path
749719
return self.infer_patches(
750720
dataloader=dataloader,
751721
return_coordinates=True,
@@ -1267,12 +1237,14 @@ def _run_wsi_mode(
12671237
masks=mask,
12681238
patch_mode=False,
12691239
ioconfig=self._ioconfig,
1240+
auto_get_mask=kwargs.get("auto_get_mask", True),
12701241
)
12711242

12721243
scale_factor = self._calculate_scale_factor(dataloader=self.dataloader)
12731244

12741245
raw_predictions = self.infer_wsi(
12751246
dataloader=self.dataloader,
1247+
save_path=save_path[image],
12761248
**kwargs,
12771249
)
12781250

@@ -1403,3 +1375,55 @@ def run(
14031375
save_dir=save_dir,
14041376
**kwargs,
14051377
)
1378+
1379+
1380+
def prepare_engines_save_dir(
1381+
save_dir: str | Path | None,
1382+
*,
1383+
patch_mode: bool,
1384+
overwrite: bool = False,
1385+
) -> Path | None:
1386+
"""Create or validate the save directory for engine outputs.
1387+
1388+
Args:
1389+
save_dir (str | Path | None):
1390+
Path to the output directory.
1391+
patch_mode (bool):
1392+
Whether the input is treated as patches.
1393+
overwrite (bool):
1394+
Whether to overwrite existing directory. Default is False.
1395+
1396+
Returns:
1397+
Path | None:
1398+
Path to the output directory if created or validated, else None.
1399+
1400+
Raises:
1401+
OSError:
1402+
If patch_mode is False and save_dir is not provided.
1403+
1404+
"""
1405+
if patch_mode:
1406+
if save_dir is not None:
1407+
save_dir = Path(save_dir)
1408+
save_dir.mkdir(parents=True, exist_ok=overwrite)
1409+
return save_dir
1410+
return None
1411+
1412+
if save_dir is None:
1413+
msg = (
1414+
"Input WSIs detected but no save directory provided. "
1415+
"Please provide a 'save_dir'."
1416+
)
1417+
raise OSError(msg)
1418+
1419+
logger.info(
1420+
"When providing multiple whole slide images, "
1421+
"the outputs will be saved and the locations of outputs "
1422+
"will be returned to the calling function when `run()` "
1423+
"finishes successfully."
1424+
)
1425+
1426+
save_dir = Path(save_dir)
1427+
save_dir.mkdir(parents=True, exist_ok=overwrite)
1428+
1429+
return save_dir

0 commit comments

Comments
 (0)