Skip to content

Commit fea5be9

Browse files
committed
✨ Add support for "power" in "scale_factor".
Signed-off-by: Shan E Ahmed Raza <[email protected]>
1 parent dd712ad commit fea5be9

File tree

2 files changed

+58
-20
lines changed

2 files changed

+58
-20
lines changed

tests/engines/test_patch_predictor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
112112
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0
113113
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
114114

115+
predictor.run(
116+
images=[mini_wsi_svs],
117+
units="power",
118+
resolution=20,
119+
patch_mode=False,
120+
save_dir=f"{tmp_path}/dump",
121+
)
122+
assert predictor._ioconfig.input_resolutions[0]["units"] == "power"
123+
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 20
124+
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
125+
115126

116127
def test_patch_predictor_api(
117128
sample_patch1: Path,

tiatoolbox/models/engine/engine_abc.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,52 @@ def _run_patch_mode(
10291029
**kwargs,
10301030
)
10311031

1032+
def _calculate_scale_factor(
1033+
self: EngineABC, dataloader: DataLoader
1034+
) -> float | tuple[float, float]:
1035+
"""Calculates scale factor for final output.
1036+
1037+
Uses the dataloader resolution and the WSI resolution to calculate scale
1038+
factor for final WSI output.
1039+
1040+
Args:
1041+
dataloader (DataLoader):
1042+
Dataloader for the current run.
1043+
1044+
Returns:
1045+
scale_factor (float | tuple[float, float]):
1046+
Scale factor for final output.
1047+
1048+
"""
1049+
# get units and resolution from dataloader.
1050+
dataloader_units = dataloader.dataset.units
1051+
dataloader_resolution = dataloader.dataset.resolution
1052+
1053+
# if dataloader units is baseline slide resolution is 1.0.
1054+
# in this case dataloader resolution / slide resolution will be
1055+
# equal to dataloader resolution.
1056+
1057+
if dataloader_units in ["mpp", "level", "objective_power"]:
1058+
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
1059+
1060+
if dataloader_units == "mpp":
1061+
slide_resolution = wsimeta_dict[dataloader_units]
1062+
scale_factor = np.divide(slide_resolution, dataloader_resolution)
1063+
return scale_factor[0], scale_factor[1]
1064+
1065+
if dataloader_units == "level":
1066+
downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution]
1067+
return 1.0 / downsample_ratio, 1.0 / downsample_ratio
1068+
1069+
if dataloader_resolution == "objective_power":
1070+
slide_objective_power = wsimeta_dict["power"]
1071+
return (
1072+
dataloader_resolution / slide_objective_power,
1073+
dataloader_resolution / slide_objective_power,
1074+
)
1075+
1076+
return dataloader_resolution
1077+
10321078
def _run_wsi_mode(
10331079
self: EngineABC,
10341080
output_type: str,
@@ -1061,26 +1107,7 @@ def _run_wsi_mode(
10611107
ioconfig=self._ioconfig,
10621108
)
10631109

1064-
# get units and resolution from dataloader.
1065-
dataloader_units = dataloader.dataset.units
1066-
dataloader_resolution = dataloader.dataset.resolution
1067-
1068-
# if dataloader units is baseline slide resolution is 1.0.
1069-
# in this case dataloader resolution / slide resolution will be
1070-
# equal to dataloader resolution.
1071-
scale_factor = dataloader_resolution
1072-
1073-
if dataloader_units == "mpp":
1074-
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
1075-
slide_resolution = wsimeta_dict[dataloader_units]
1076-
scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution))
1077-
1078-
if dataloader_units == "level":
1079-
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
1080-
downsample_ratio = wsimeta_dict["level_downsamples"][
1081-
dataloader_resolution
1082-
]
1083-
scale_factor = (1.0 / downsample_ratio, 1.0 / downsample_ratio)
1110+
scale_factor = self._calculate_scale_factor(dataloader=dataloader)
10841111

10851112
raw_predictions = self.infer_wsi(
10861113
dataloader=dataloader,

0 commit comments

Comments
 (0)