Skip to content

Commit 606a2c0

Browse files
committed
🎨 Update code structure for dataloader_units
1 parent da0ce4f commit 606a2c0

File tree

3 files changed

+35
-9
lines changed

3 files changed

+35
-9
lines changed

tests/engines/test_patch_predictor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
101101
assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline"
102102
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
103103

104+
predictor.run(
105+
images=[mini_wsi_svs],
106+
units="level",
107+
resolution=0,
108+
patch_mode=False,
109+
save_dir=f"{tmp_path}/dump",
110+
)
111+
assert predictor._ioconfig.input_resolutions[0]["units"] == "level"
112+
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0
113+
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
114+
104115

105116
def test_patch_predictor_api(
106117
sample_patch1: Path,

tiatoolbox/models/engine/engine_abc.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import copy
6+
import shutil
67
from abc import ABC, abstractmethod
78
from pathlib import Path
89
from typing import TYPE_CHECKING, TypedDict
@@ -639,7 +640,7 @@ def post_process_patches(
639640

640641
def save_predictions(
641642
self: EngineABC,
642-
processed_predictions: dict,
643+
processed_predictions: dict | Path,
643644
output_type: str,
644645
save_dir: Path | None = None,
645646
**kwargs: dict,
@@ -681,16 +682,23 @@ def save_predictions(
681682
# class_dict set from kwargs
682683
class_dict = kwargs.get("class_dict")
683684

685+
processed_predictions_path: str | Path | None = None
686+
684687
# Need to add support for zarr conversion.
685688
if self.cache_mode:
689+
processed_predictions_path = processed_predictions
686690
processed_predictions = zarr.open(processed_predictions, mode="r")
687691

688-
return dict_to_store(
692+
out_file = dict_to_store(
689693
processed_predictions,
690694
scale_factor,
691695
class_dict,
692696
save_path,
693697
)
698+
if processed_predictions_path is not None:
699+
shutil.rmtree(processed_predictions_path)
700+
701+
return out_file
694702

695703
return (
696704
dict_to_zarr(
@@ -1057,15 +1065,22 @@ def _run_wsi_mode(
10571065
dataloader_units = dataloader.dataset.units
10581066
dataloader_resolution = dataloader.dataset.resolution
10591067

1060-
slide_resolution = (1.0, 1.0)
1061-
if dataloader_units != "baseline":
1068+
# if dataloader units is baseline slide resolution is 1.0.
1069+
# in this case dataloader resolution / slide resolution will be
1070+
# equal to dataloader resolution.
1071+
scale_factor = dataloader_resolution
1072+
1073+
if dataloader_units == "mpp":
10621074
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
10631075
slide_resolution = wsimeta_dict[dataloader_units]
1076+
scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution))
10641077

1065-
scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution))
1066-
1067-
if dataloader_units != "mpp":
1068-
scale_factor = tuple(np.divide(dataloader_resolution, slide_resolution))
1078+
if dataloader_units == "level":
1079+
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
1080+
downsample_ratio = wsimeta_dict["level_downsamples"][
1081+
dataloader_resolution
1082+
]
1083+
scale_factor = (1.0 / downsample_ratio, 1.0 / downsample_ratio)
10691084

10701085
raw_predictions = self.infer_wsi(
10711086
dataloader=dataloader,

tiatoolbox/utils/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1309,7 +1309,7 @@ def dict_to_store(
13091309

13101310
# if a save director is provided, then dump store into a file
13111311
if save_path:
1312-
# ensure parent directory exisits
1312+
# ensure parent directory exists
13131313
save_path.parent.absolute().mkdir(parents=True, exist_ok=True)
13141314
# ensure proper db extension
13151315
save_path = save_path.parent.absolute() / (save_path.stem + ".db")

0 commit comments

Comments
 (0)