Skip to content

Commit a643ea6

Browse files
committed
♻️ Use input_resolutions instead of resolution
- Use `input_resolutions` instead of resolution to make engines outputs compatible with ioconfig. - Uses input resolution as a list of dictionaries on units and resolution.
1 parent 61624d4 commit a643ea6

File tree

5 files changed

+33
-31
lines changed

5 files changed

+33
-31
lines changed

tests/engines/test_engine_abc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
530530
patch_mode=True,
531531
save_dir=tmp_path / "dump",
532532
patch_input_shape=kwargs["patch_input_shape"],
533-
resolution=kwargs["resolution"],
533+
input_resolutions=kwargs["resolution"],
534534
units=kwargs["units"],
535535
)
536536
assert "provide a valid ModelIOConfigABC" in caplog.text
@@ -570,7 +570,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
570570
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
571571
patch_input_shape=(300, 300),
572572
stride_shape=(300, 300),
573-
resolution=1.99,
573+
input_resolutions=1.99,
574574
units="baseline",
575575
patch_mode=True,
576576
save_dir=f"{tmp_path}/dump",
@@ -585,7 +585,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
585585
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
586586
patch_input_shape=(300, 300),
587587
stride_shape=(300, 300),
588-
resolution=None,
588+
input_resolutions=None,
589589
units=None,
590590
patch_mode=True,
591591
save_dir=f"{tmp_path}/dump",
@@ -599,7 +599,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
599599
ioconfig=None,
600600
patch_input_shape=(300, 300),
601601
stride_shape=(300, 300),
602-
resolution=1.99,
602+
input_resolutions=1.99,
603603
units="baseline",
604604
)
605605

@@ -620,6 +620,6 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
620620
ioconfig=None,
621621
patch_input_shape=_kwargs["patch_input_shape"],
622622
stride_shape=(1, 1),
623-
resolution=_kwargs["resolution"],
623+
input_resolutions=_kwargs["resolution"],
624624
units=_kwargs["units"],
625625
)

tests/engines/test_patch_predictor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Test for Patch Predictor."""
1+
"""Test PatchPredictor."""
22

33
from __future__ import annotations
44

@@ -134,7 +134,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
134134

135135
predictor.run(
136136
images=[mini_wsi_svs],
137-
resolution=1.99,
137+
input_resolutions=1.99,
138138
patch_mode=False,
139139
save_dir=f"{tmp_path}/dump",
140140
)
@@ -153,7 +153,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
153153
predictor.run(
154154
images=[mini_wsi_svs],
155155
units="level",
156-
resolution=0,
156+
input_resolutions=0,
157157
patch_mode=False,
158158
save_dir=f"{tmp_path}/dump",
159159
)
@@ -164,7 +164,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
164164
predictor.run(
165165
images=[mini_wsi_svs],
166166
units="power",
167-
resolution=20,
167+
input_resolutions=20,
168168
patch_mode=False,
169169
save_dir=f"{tmp_path}/dump",
170170
)

tiatoolbox/models/engine/engine_abc.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from tiatoolbox.annotation import AnnotationStore
3737
from tiatoolbox.models.models_abc import ModelABC
38-
from tiatoolbox.typing import IntPair, Resolution, Units
38+
from tiatoolbox.type_hints import IntPair, Resolution, Units
3939
from tiatoolbox.wsicore.wsireader import WSIReader
4040

4141

@@ -132,7 +132,7 @@ class EngineABCRunParams(TypedDict, total=False):
132132
Shape of patches input to the model as tuple of height and width (HW).
133133
Patches are requested at read resolution, not with respect to level 0,
134134
and must be positive.
135-
resolution (Resolution):
135+
input_resolutions (Resolution):
136136
Resolution used for reading the image. Please see
137137
:class:`WSIReader` for details.
138138
scale_factor (tuple[float, float]):
@@ -164,7 +164,7 @@ class EngineABCRunParams(TypedDict, total=False):
164164
num_post_proc_workers: int
165165
output_file: str
166166
patch_input_shape: IntPair
167-
resolution: Resolution
167+
input_resolutions: Resolution
168168
return_labels: bool
169169
scale_factor: tuple[float, float]
170170
stride_shape: IntPair
@@ -242,7 +242,7 @@ class EngineABC(ABC): # noqa: B024
242242
Runtime ioconfig.
243243
return_labels (bool):
244244
Whether to return the labels with the predictions.
245-
resolution (Resolution):
245+
input_resolutions (Resolution):
246246
Resolution used for reading the image. Please see
247247
:obj:`WSIReader` for details.
248248
units (Units):
@@ -283,7 +283,7 @@ class EngineABC(ABC): # noqa: B024
283283
Number of workers to postprocess the results of the model.
284284
return_labels (bool):
285285
Whether to return the output labels. Default value is False.
286-
resolution (Resolution):
286+
input_resolutions (Resolution):
287287
Resolution used for reading the image. Please see
288288
:class:`WSIReader` for details.
289289
When `patch_mode` is True, the input image patches are expected to be at
@@ -371,7 +371,7 @@ def __init__(
371371
self.num_loader_workers = num_loader_workers
372372
self.num_post_proc_workers = num_post_proc_workers
373373
self.patch_input_shape: IntPair | None = None
374-
self.resolution: Resolution | None = None
374+
self.input_resolutions: Resolution | None = None
375375
self.return_labels: bool = False
376376
self.stride_shape: IntPair | None = None
377377
self.units: Units | None = None
@@ -791,7 +791,7 @@ def _update_ioconfig(
791791
ioconfig: ModelIOConfigABC,
792792
patch_input_shape: IntPair,
793793
stride_shape: IntPair,
794-
resolution: Resolution,
794+
input_resolutions: Resolution,
795795
units: Units,
796796
) -> ModelIOConfigABC:
797797
"""Update IOConfig.
@@ -808,7 +808,7 @@ def _update_ioconfig(
808808
at requested read resolution, not with respect to
809809
level 0, and must be positive. If not provided,
810810
`stride_shape=patch_input_shape`.
811-
resolution (Resolution):
811+
input_resolutions (Resolution):
812812
Resolution used for reading the image. Please see
813813
:obj:`WSIReader` for details.
814814
units (Units):
@@ -820,7 +820,7 @@ def _update_ioconfig(
820820
"""
821821
config_flag = (
822822
patch_input_shape is None,
823-
resolution is None,
823+
input_resolutions is None,
824824
units is None,
825825
)
826826
if isinstance(ioconfig, ModelIOConfigABC):
@@ -845,15 +845,15 @@ def _update_ioconfig(
845845
ioconfig.patch_input_shape = patch_input_shape
846846
if stride_shape is not None:
847847
ioconfig.stride_shape = stride_shape
848-
if resolution is not None:
849-
ioconfig.input_resolutions[0]["resolution"] = resolution
848+
if input_resolutions is not None:
849+
ioconfig.input_resolutions[0]["resolution"] = input_resolutions
850850
if units is not None:
851851
ioconfig.input_resolutions[0]["units"] = units
852852

853853
return ioconfig
854854

855855
return ModelIOConfigABC(
856-
input_resolutions=[{"resolution": resolution, "units": units}],
856+
input_resolutions=[{"resolution": input_resolutions, "units": units}],
857857
patch_input_shape=patch_input_shape,
858858
stride_shape=stride_shape,
859859
output_resolutions=[],
@@ -955,7 +955,7 @@ def _update_run_params(
955955
ioconfig,
956956
self.patch_input_shape,
957957
self.stride_shape,
958-
self.resolution,
958+
self.input_resolutions,
959959
self.units,
960960
)
961961

tiatoolbox/models/engine/io_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
if TYPE_CHECKING: # pragma: no cover
11-
from tiatoolbox.typing import Units
11+
from tiatoolbox.type_hints import Units
1212

1313

1414
@dataclass
@@ -69,6 +69,7 @@ def __post_init__(self: ModelIOConfigABC) -> None:
6969
self.stride_shape = self.patch_input_shape
7070

7171
self.resolution_unit = self.input_resolutions[0]["units"]
72+
self.highest_input_resolution = self.input_resolutions[0]["resolution"]
7273

7374
if self.resolution_unit == "mpp":
7475
self.highest_input_resolution = min(

tiatoolbox/models/engine/patch_predictor.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class PredictorRunParams(EngineABCRunParams):
5858
Shape of patches input to the model as tuple of height and width (HW).
5959
Patches are requested at read resolution, not with respect to level 0,
6060
and must be positive.
61-
resolution (Resolution):
61+
input_resolutions (Resolution):
6262
Resolution used for reading the image. Please see
6363
:class:`WSIReader` for details.
6464
return_probabilities (bool):
@@ -239,7 +239,7 @@ class PatchPredictor(EngineABC):
239239
Runtime ioconfig.
240240
return_labels (bool):
241241
Whether to return the labels with the predictions.
242-
resolution (Resolution):
242+
input_resolutions (Resolution):
243243
Resolution used for reading the image. Please see
244244
:obj:`WSIReader` for details.
245245
units (Units):
@@ -280,7 +280,7 @@ class PatchPredictor(EngineABC):
280280
Number of workers to postprocess the results of the model.
281281
return_labels (bool):
282282
Whether to return the output labels. Default value is False.
283-
resolution (Resolution):
283+
input_resolutions (Resolution):
284284
Resolution used for reading the image. Please see
285285
:class:`WSIReader` for details.
286286
When `patch_mode` is True, the input image patches are expected to be at
@@ -301,27 +301,27 @@ class PatchPredictor(EngineABC):
301301
>>> # list of 2 image patches as input
302302
>>> data = ['path/img.svs', 'path/img.svs']
303303
>>> predictor = PatchPredictor(model="resnet18-kather100k")
304-
>>> output = predictor.run(data, mode='patch')
304+
>>> output = predictor.run(data, patch_mode=False)
305305
306306
>>> # array of list of 2 image patches as input
307307
>>> data = np.array([img1, img2])
308308
>>> predictor = PatchPredictor(model="resnet18-kather100k")
309-
>>> output = predictor.run(data, mode='patch')
309+
>>> output = predictor.run(data, patch_mode=True)
310310
311311
>>> # list of 2 image patch files as input
312312
>>> data = ['path/img.png', 'path/img.png']
313313
>>> predictor = PatchPredictor(model="resnet18-kather100k")
314-
>>> output = predictor.run(data, mode='patch')
314+
>>> output = predictor.run(data, patch_mode=True)
315315
316316
>>> # list of 2 image tile files as input
317317
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
318318
>>> predictor = PatchPredictor(model="resnet18-kather100k")
319-
>>> output = predictor.run(tile_file, mode='tile')
319+
>>> output = predictor.run(tile_file, patch_mode=False)
320320
321321
>>> # list of 2 wsi files as input
322322
>>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
323323
>>> predictor = PatchPredictor(model="resnet18-kather100k")
324-
>>> output = predictor.run(wsi_file, mode='wsi')
324+
>>> output = predictor.run(wsi_file, patch_mode=False)
325325
326326
References:
327327
[1] Kather, Jakob Nikolas, et al. "Predicting survival from colorectal cancer
@@ -517,6 +517,7 @@ def run(
517517
518518
Examples:
519519
>>> wsis = ['wsi1.svs', 'wsi2.svs']
520+
>>> image_patches = [np.ndarray, np.ndarray]
520521
>>> class PatchPredictor(EngineABC):
521522
>>> # Define all Abstract methods.
522523
>>> ...

0 commit comments

Comments
 (0)