Skip to content

Commit 775aa0a

Browse files
committed
🔥 Remove cache_mode
1 parent 9f9a709 commit 775aa0a

File tree

5 files changed

+7
-75
lines changed

5 files changed

+7
-75
lines changed

tests/engines/test_engine_abc.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -459,37 +459,6 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn:
459459
)
460460

461461

462-
def test_cache_mode_patches(tmp_path: pytest.TempPathFactory) -> NoReturn:
463-
"""Test the caching mode."""
464-
save_dir = tmp_path / "patch_output"
465-
466-
eng = TestEngineABC(model="alexnet-kather100k")
467-
out = eng.run(
468-
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
469-
on_gpu=False,
470-
save_dir=save_dir,
471-
overwrite=True,
472-
cache_mode=True,
473-
)
474-
assert out.exists(), "Zarr output file does not exist"
475-
476-
output_file_name = "output2.zarr"
477-
cache_size = 4
478-
out = eng.run(
479-
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
480-
on_gpu=False,
481-
save_dir=save_dir,
482-
overwrite=True,
483-
cache_mode=True,
484-
cache_size=4,
485-
batch_size=8,
486-
output_file=output_file_name,
487-
)
488-
assert out.stem == output_file_name.split(".")[0]
489-
assert eng.batch_size == cache_size
490-
assert out.exists(), "Zarr output file does not exist"
491-
492-
493462
def test_get_dataloader(sample_svs: Path) -> None:
494463
"""Test the get_dataloader function."""
495464
eng = TestEngineABC(model="alexnet-kather100k")

tests/engines/test_semantic_segmentor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_semantic_segmentor_patches(remote_sample: Callable, tmp_path: Path) ->
3737

3838
inputs = [sample_image, sample_image]
3939

40-
assert segmentor.cache_mode is False
40+
assert not segmentor.patch_mode
4141

4242
output = segmentor.run(
4343
images=inputs,

tiatoolbox/models/engine/engine_abc.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,6 @@ class EngineABCRunParams(TypedDict, total=False):
129129
Optional Keys:
130130
batch_size (int):
131131
Number of image patches per forward pass.
132-
cache_mode (bool):
133-
Whether to use caching for large datasets.
134-
cache_size (int):
135-
Number of patches to process in a batch when caching.
136132
class_dict (dict):
137133
Mapping of classification outputs to class names.
138134
device (str):
@@ -173,8 +169,6 @@ class EngineABCRunParams(TypedDict, total=False):
173169
"""
174170

175171
batch_size: int
176-
cache_mode: bool
177-
cache_size: int
178172
class_dict: dict
179173
device: str
180174
ioconfig: ModelIOConfigABC
@@ -277,10 +271,6 @@ class EngineABC(ABC): # noqa: B024
277271
`stride_shape=patch_input_shape`.
278272
batch_size (int):
279273
Number of images fed into the model each time.
280-
cache_mode (bool):
281-
Whether to use caching for large datasets.
282-
cache_size (int):
283-
Number of patches to process in a batch when caching.
284274
labels (list | None):
285275
Optional labels for input images.
286276
Only a single label per image is supported.
@@ -370,8 +360,6 @@ def __init__(
370360
)
371361
self._ioconfig = self.ioconfig # runtime ioconfig
372362
self.batch_size = batch_size
373-
self.cache_mode: bool = False
374-
self.cache_size: int = self.batch_size if self.batch_size else 10000
375363
self.labels: list | None = None
376364
self.num_loader_workers = num_loader_workers
377365
self.num_post_proc_workers = num_post_proc_workers
@@ -1065,22 +1053,13 @@ def _update_run_params(
10651053
self.drop_keys.append("label")
10661054

10671055
self.patch_mode = patch_mode
1068-
if not self.patch_mode:
1069-
self.cache_mode = True # if input is WSI run using cache mode.
1070-
1071-
if self.cache_mode and self.batch_size > self.cache_size:
1072-
self.batch_size = self.cache_size
10731056

10741057
self._validate_input_numbers(images=images, masks=masks, labels=labels)
10751058
if output_type.lower() not in ["dict", "zarr", "annotationstore"]:
10761059
msg = "output_type must be 'dict' or 'zarr' or 'annotationstore'."
10771060
raise TypeError(msg)
10781061

10791062
self.output_type = output_type
1080-
if self.cache_mode and output_type.lower() not in ["zarr", "annotationstore"]:
1081-
self.output_type = "zarr"
1082-
msg = "output_type has been updated to 'zarr' for cache_mode=True."
1083-
logger.info(msg)
10841063

10851064
if save_dir is not None and output_type.lower() not in [
10861065
"zarr",
@@ -1149,7 +1128,7 @@ def _run_patch_mode(
11491128
11501129
"""
11511130
save_path = None
1152-
if self.cache_mode or save_dir:
1131+
if save_dir:
11531132
output_file = Path(kwargs.get("output_file", "output.zarr"))
11541133
save_path = save_dir / (str(output_file.stem) + ".zarr")
11551134

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
from typing import TYPE_CHECKING
66

7-
import dask.array as da
8-
from dask import delayed
97
from typing_extensions import Unpack
108

119
from .engine_abc import EngineABC, EngineABCRunParams
@@ -14,6 +12,7 @@
1412
import os
1513
from pathlib import Path
1614

15+
import dask.array as da
1716
import numpy as np
1817

1918
from tiatoolbox.annotation import AnnotationStore
@@ -190,10 +189,6 @@ class PatchPredictor(EngineABC):
190189
at requested read resolution, not with respect to
191190
level 0, and must be positive. If not provided,
192191
`stride_shape=patch_input_shape`.
193-
cache_mode (bool):
194-
Whether to use caching for large datasets.
195-
cache_size (int):
196-
Number of patches to process in a batch when caching.
197192
labels (list | None):
198193
Optional labels for input images.
199194
Only a single label per image is supported.
@@ -310,12 +305,9 @@ def post_process_patches(
310305
311306
"""
312307
_ = kwargs.get("return_probabilities")
313-
raw_predictions = delayed(self.model.postproc_func)(raw_predictions)
314-
return da.from_delayed(
315-
raw_predictions,
316-
shape=prediction_shape,
317-
dtype=prediction_dtype,
318-
)
308+
_ = prediction_shape
309+
_ = prediction_dtype
310+
return self.model.postproc_func(raw_predictions)
319311

320312
def post_process_wsi(
321313
self: PatchPredictor,

tiatoolbox/models/engine/semantic_segmentor.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,6 @@ class SemanticSegmentorRunParams(PredictorRunParams, total=False):
8282
Attributes:
8383
batch_size (int):
8484
Number of image patches to feed to the model in a forward pass.
85-
cache_mode (bool):
86-
Whether to run the engine in cache mode. Recommended for large datasets.
87-
cache_size (int):
88-
Number of patches to process in a batch when cache_mode is True.
8985
class_dict (dict):
9086
Optional dictionary mapping classification outputs to class names.
9187
device (str):
@@ -223,10 +219,6 @@ class SemanticSegmentor(PatchPredictor):
223219
at requested read resolution, not with respect to
224220
level 0, and must be positive. If not provided,
225221
`stride_shape=patch_input_shape`.
226-
cache_mode (bool):
227-
Whether to use caching for large datasets.
228-
cache_size (int):
229-
Number of patches to process in a batch when caching.
230222
labels (list | None):
231223
Optional labels for input images.
232224
Only a single label per image is supported.
@@ -587,7 +579,7 @@ def save_predictions(
587579
)
588580
save_paths = out_file
589581

590-
if return_probabilities and self.cache_mode:
582+
if return_probabilities:
591583
zarr_save_path = save_path.parent.with_suffix(".zarr")
592584
msg = (
593585
f"Probability maps cannot be saved as AnnotationStore. "

0 commit comments

Comments
 (0)