Skip to content

Commit b3b5642

Browse files
committed
replace cache_dir with temp dir
1 parent 0be5bfc commit b3b5642

File tree

4 files changed

+29
-37
lines changed

4 files changed

+29
-37
lines changed

tests/engines/test_nucleus_detection_engine.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ def test_nucleus_detector_wsi(remote_sample: Callable, track_tmp_path: Path) ->
256256
memory_threshold=50,
257257
images=[mini_wsi_svs],
258258
save_dir=save_dir,
259-
cache_dir=save_dir,
260259
overwrite=True,
261260
batch_size=8,
262261
class_dict={0: "test_nucleus"},
@@ -278,7 +277,6 @@ def test_nucleus_detector_wsi(remote_sample: Callable, track_tmp_path: Path) ->
278277
memory_threshold=50,
279278
images=[mini_wsi_svs],
280279
save_dir=save_dir,
281-
cache_dir=save_dir,
282280
overwrite=True,
283281
batch_size=8,
284282
)

tiatoolbox/cli/common.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,6 @@ def cli_output_path(
8181
)
8282

8383

84-
def cli_cache_path(
85-
usage_help: str = "Path to cache directory to save the cache.",
86-
default: str | None = None,
87-
) -> Callable:
88-
"""Enables --cache-path option for cli."""
89-
return click.option(
90-
"--cache-path",
91-
help=add_default_to_usage_help(usage_help, default=default),
92-
type=str,
93-
default=default,
94-
)
95-
96-
9784
def cli_output_file(
9885
usage_help: str = "Filename for saving output (e.g., '.zarr' or '.db').",
9986
default: str | None = None,

tiatoolbox/cli/nucleus_detector.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33
from __future__ import annotations
44

5-
import shutil
65
from typing import TYPE_CHECKING
76

87
from tiatoolbox.cli.common import (
98
cli_auto_get_mask,
109
cli_batch_size,
11-
cli_cache_path,
1210
cli_class_dict,
1311
cli_device,
1412
cli_file_type,
@@ -51,10 +49,6 @@
5149
usage_help="Output directory where model prediction will be saved.",
5250
default="nucleus_detection",
5351
)
54-
@cli_cache_path(
55-
usage_help="Directory to use for caching intermediate files.",
56-
default="nucleus_detection",
57-
)
5852
@cli_output_file(default=None)
5953
@cli_file_type(
6054
default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs",
@@ -96,7 +90,6 @@ def nucleus_detector(
9690
output_resolutions: list[dict],
9791
masks: str | None,
9892
output_path: str,
99-
cache_path: str,
10093
patch_input_shape: IntPair | None,
10194
patch_output_shape: tuple[int, int] | None,
10295
stride_shape: IntPair | None,
@@ -158,7 +151,6 @@ def nucleus_detector(
158151
ioconfig=ioconfig,
159152
device=device,
160153
save_dir=output_path,
161-
cache_dir=cache_path,
162154
output_type=output_type,
163155
return_probabilities=return_probabilities,
164156
auto_get_mask=auto_get_mask,
@@ -174,5 +166,3 @@ def nucleus_detector(
174166
overwrite=overwrite,
175167
verbose=verbose,
176168
)
177-
178-
shutil.rmtree(cache_path)

tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from __future__ import annotations
4848

4949
import shutil
50+
import tempfile
5051
from pathlib import Path
5152
from typing import TYPE_CHECKING
5253

@@ -116,8 +117,6 @@ class NucleusDetectorRunParams(SemanticSegmentorRunParams, total=False):
116117
postproc_tile_shape (tuple[int, int]):
117118
Tile shape (height, width) used during post-processing
118119
(in pixels) to control rechunking behavior.
119-
cache_dir (str or os.PathLike):
120-
Directory for caching intermediate results during WSI processing.
121120
return_labels (bool):
122121
Whether to return labels with predictions.
123122
return_probabilities (bool):
@@ -136,7 +135,6 @@ class NucleusDetectorRunParams(SemanticSegmentorRunParams, total=False):
136135
threshold_abs: float
137136
threshold_rel: float
138137
postproc_tile_shape: IntPair
139-
cache_dir: str | os.PathLike
140138

141139

142140
class NucleusDetector(SemanticSegmentor):
@@ -478,17 +476,18 @@ def post_process_wsi(
478476
depth_w=depth_w,
479477
)
480478

481-
logger.info("Computing and saving centroid maps to cache as zarr.")
482-
save_dir = kwargs.get("cache_dir", "./tmp/")
483-
save_path = Path(save_dir) / "detection_maps"
484-
if save_path.exists():
485-
shutil.rmtree(save_path)
486-
487-
task = centroid_maps.to_zarr(url=save_path, compute=False, object_codec=None)
479+
logger.info("Computing and saving centroid maps to temporary zarr file.")
480+
temp_zarr_file = tempfile.TemporaryDirectory(
481+
prefix="tiatoolbox_nucleus_detector_", suffix=".zarr"
482+
)
483+
logger.info("Temporary zarr file created at: %s", temp_zarr_file.name)
484+
task = centroid_maps.to_zarr(
485+
url=temp_zarr_file.name, compute=False, object_codec=None
486+
)
488487
with ProgressBar():
489488
compute(task)
490489

491-
centroid_maps = da.from_zarr(save_path)
490+
centroid_maps = da.from_zarr(temp_zarr_file.name)
492491

493492
return self._centroid_maps_to_detection_arrays(centroid_maps)
494493

@@ -1112,7 +1111,7 @@ class names.
11121111
11131112
11141113
"""
1115-
return super().run(
1114+
output = super().run(
11161115
images=images,
11171116
masks=masks,
11181117
input_resolutions=input_resolutions,
@@ -1124,3 +1123,21 @@ class names.
11241123
output_type=output_type,
11251124
**kwargs,
11261125
)
1126+
1127+
if not patch_mode:
1128+
# Clean up temporary zarr directory after WSI processing
1129+
# It should have been already deleted, but check anyway
1130+
temp_dir = Path(tempfile.gettempdir())
1131+
if temp_dir.exists():
1132+
# find file starting with 'tiatoolbox_nucleus_detector_'
1133+
# and ending with '.zarr'
1134+
for item in temp_dir.iterdir():
1135+
if item.name.startswith(
1136+
"tiatoolbox_nucleus_detector_"
1137+
) and item.name.endswith(".zarr"):
1138+
shutil.rmtree(item)
1139+
logger.info(
1140+
"Temporary zarr directory %s has been removed.", item
1141+
)
1142+
1143+
return output

0 commit comments

Comments
 (0)