Skip to content

Commit 8c89911

Browse files
authored
♻️ Use input_resolutions instead of resolution in Engine Params (#917)
* ♻️ Use `input_resolutions` instead of resolution - Use `input_resolutions` instead of resolution to make engines outputs compatible with ioconfig. - Uses input resolution as a list of dictionaries on units and resolution. * ♻️ Use `input_resolutions` instead of resolution - Use `input_resolutions` instead of resolution to make engines outputs compatible with ioconfig. - Uses input resolution as a list of dictionaries on units and resolution. * ✅ Add test to cli.
1 parent 0ba148c commit 8c89911

File tree

10 files changed

+129
-131
lines changed

10 files changed

+129
-131
lines changed

tests/engines/test_engine_abc.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -521,17 +521,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
521521

522522
kwargs = {
523523
"patch_input_shape": [512, 512],
524-
"resolution": 1.75,
525-
"units": "mpp",
524+
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
526525
}
527526
with caplog.at_level(logging.WARNING):
528527
eng.run(
529528
np.zeros((10, 224, 224, 3)),
530529
patch_mode=True,
531530
save_dir=tmp_path / "dump",
532531
patch_input_shape=kwargs["patch_input_shape"],
533-
resolution=kwargs["resolution"],
534-
units=kwargs["units"],
532+
input_resolutions=kwargs["input_resolutions"],
535533
)
536534
assert "provide a valid ModelIOConfigABC" in caplog.text
537535
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
@@ -570,8 +568,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
570568
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
571569
patch_input_shape=(300, 300),
572570
stride_shape=(300, 300),
573-
resolution=1.99,
574-
units="baseline",
571+
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
575572
patch_mode=True,
576573
save_dir=f"{tmp_path}/dump",
577574
)
@@ -585,8 +582,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
585582
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
586583
patch_input_shape=(300, 300),
587584
stride_shape=(300, 300),
588-
resolution=None,
589-
units=None,
585+
input_resolutions=None,
590586
patch_mode=True,
591587
save_dir=f"{tmp_path}/dump",
592588
)
@@ -599,8 +595,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
599595
ioconfig=None,
600596
patch_input_shape=(300, 300),
601597
stride_shape=(300, 300),
602-
resolution=1.99,
603-
units="baseline",
598+
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
604599
)
605600

606601
assert _ioconfig.patch_input_shape == (300, 300)
@@ -614,12 +609,11 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
614609
with pytest.raises(
615610
ValueError,
616611
match=r".*Must provide either `ioconfig` or "
617-
r"`patch_input_shape`, `resolution`, and `units`*",
612+
r"`patch_input_shape` and `input_resolutions`*",
618613
):
619614
eng._update_ioconfig(
620615
ioconfig=None,
621616
patch_input_shape=_kwargs["patch_input_shape"],
622617
stride_shape=(1, 1),
623-
resolution=_kwargs["resolution"],
624-
units=_kwargs["units"],
618+
input_resolutions=_kwargs["input_resolutions"],
625619
)

tests/engines/test_patch_predictor.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Test for Patch Predictor."""
1+
"""Test PatchPredictor."""
22

33
from __future__ import annotations
44

@@ -11,12 +11,14 @@
1111

1212
import numpy as np
1313
import torch
14+
import yaml
1415
import zarr
1516
from click.testing import CliRunner
1617

1718
from tests.conftest import timed
1819
from tiatoolbox import cli, logger, rcParam
1920
from tiatoolbox.models import IOPatchPredictorConfig
21+
from tiatoolbox.models.architecture import fetch_pretrained_weights
2022
from tiatoolbox.models.architecture.vanilla import CNNModel
2123
from tiatoolbox.models.engine.patch_predictor import PatchPredictor
2224
from tiatoolbox.utils import env_detection as toolbox_env
@@ -86,8 +88,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
8688
predictor = PatchPredictor(model=model, weights=None)
8789
kwargs = {
8890
"patch_input_shape": [512, 512],
89-
"resolution": 1.75,
90-
"units": "mpp",
91+
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
9192
}
9293

9394
# test providing config / full input info for default models without weights
@@ -134,7 +135,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
134135

135136
predictor.run(
136137
images=[mini_wsi_svs],
137-
resolution=1.99,
138+
input_resolutions=[{"units": "mpp", "resolution": 1.99}],
138139
patch_mode=False,
139140
save_dir=f"{tmp_path}/dump",
140141
)
@@ -143,7 +144,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
143144

144145
predictor.run(
145146
images=[mini_wsi_svs],
146-
units="baseline",
147+
input_resolutions=[{"units": "baseline", "resolution": 1.0}],
147148
patch_mode=False,
148149
save_dir=f"{tmp_path}/dump",
149150
)
@@ -152,8 +153,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
152153

153154
predictor.run(
154155
images=[mini_wsi_svs],
155-
units="level",
156-
resolution=0,
156+
input_resolutions=[{"units": "level", "resolution": 0}],
157157
patch_mode=False,
158158
save_dir=f"{tmp_path}/dump",
159159
)
@@ -163,8 +163,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
163163

164164
predictor.run(
165165
images=[mini_wsi_svs],
166-
units="power",
167-
resolution=20,
166+
input_resolutions=[{"units": "power", "resolution": 20}],
168167
patch_mode=False,
169168
save_dir=f"{tmp_path}/dump",
170169
)
@@ -262,8 +261,7 @@ def test_wsi_predictor_api(
262261
kwargs = {
263262
"patch_input_shape": patch_size,
264263
"stride_shape": patch_size,
265-
"resolution": 1.0,
266-
"units": "baseline",
264+
"input_resolutions": [{"units": "baseline", "resolution": 1.0}],
267265
"save_dir": save_dir,
268266
}
269267
# ! add this test back once the read at `baseline` is fixed
@@ -646,6 +644,17 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
646644
dir_path_masks = tmp_path.joinpath("new_copies_masks")
647645
dir_path_masks.mkdir()
648646

647+
config = {
648+
"input_resolutions": [{"units": "mpp", "resolution": 0.5}],
649+
"patch_input_shape": [224, 224],
650+
}
651+
652+
with Path.open(tmp_path.joinpath("config.yaml"), "w") as fptr:
653+
yaml.dump(config, fptr)
654+
655+
model = "alexnet-kather100k"
656+
weights = fetch_pretrained_weights(model)
657+
649658
try:
650659
dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs)
651660
dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs)
@@ -675,6 +684,12 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
675684
str(False),
676685
"--masks",
677686
str(dir_path_masks),
687+
"--model",
688+
model,
689+
"--weights",
690+
str(weights),
691+
"--yaml-config-path",
692+
tmp_path / "config.yaml",
678693
"--output-path",
679694
str(tmp_path / "output"),
680695
"--output-type",

tiatoolbox/cli/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,17 +619,17 @@ def prepare_model_cli(
619619
tiatoolbox_cli = TIAToolboxCLI()
620620

621621

622-
def prepare_ioconfig_seg(
623-
segment_config_class: type[IOConfigABC],
622+
def prepare_ioconfig(
623+
config_class: type[IOConfigABC],
624624
pretrained_weights: str | Path | None,
625625
yaml_config_path: str | Path,
626626
) -> IOConfigABC | None:
627-
"""Prepare ioconfig for segmentation."""
627+
"""Prepare ioconfig for CLI."""
628628
import yaml
629629

630630
if pretrained_weights is not None:
631631
with Path(yaml_config_path).open() as registry_handle:
632632
ioconfig = yaml.safe_load(registry_handle)
633-
return segment_config_class(**ioconfig)
633+
return config_class(**ioconfig)
634634

635635
return None

tiatoolbox/cli/nucleus_instance_segment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
cli_pretrained_weights,
2020
cli_verbose,
2121
cli_yaml_config_path,
22-
prepare_ioconfig_seg,
22+
prepare_ioconfig,
2323
prepare_model_cli,
2424
tiatoolbox_cli,
2525
)
@@ -77,7 +77,7 @@ def nucleus_instance_segment(
7777
file_types=file_types,
7878
)
7979

80-
ioconfig = prepare_ioconfig_seg(
80+
ioconfig = prepare_ioconfig(
8181
IOInstanceSegmentorConfig,
8282
pretrained_weights,
8383
yaml_config_path,

tiatoolbox/cli/patch_predictor.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
cli_output_path,
1414
cli_output_type,
1515
cli_patch_mode,
16-
cli_resolution,
1716
cli_return_labels,
1817
cli_return_probabilities,
19-
cli_units,
2018
cli_verbose,
2119
cli_weights,
20+
cli_yaml_config_path,
21+
prepare_ioconfig,
2222
prepare_model_cli,
2323
tiatoolbox_cli,
2424
)
@@ -37,8 +37,7 @@
3737
@cli_weights()
3838
@cli_device(default="cpu")
3939
@cli_batch_size(default=1)
40-
@cli_resolution(default=0.5)
41-
@cli_units(default="mpp")
40+
@cli_yaml_config_path()
4241
@cli_masks(default=None)
4342
@cli_num_loader_workers(default=0)
4443
@cli_output_type(
@@ -56,8 +55,7 @@ def patch_predictor(
5655
masks: str | None,
5756
output_path: str,
5857
batch_size: int,
59-
resolution: float,
60-
units: str,
58+
yaml_config_path: str,
6159
num_loader_workers: int,
6260
device: str,
6361
output_type: str,
@@ -68,6 +66,7 @@ def patch_predictor(
6866
verbose: bool,
6967
) -> None:
7068
"""Process an image/directory of input images with a patch classification CNN."""
69+
from tiatoolbox.models.engine.io_config import IOPatchPredictorConfig
7170
from tiatoolbox.models.engine.patch_predictor import PatchPredictor
7271

7372
files_all, masks_all, output_path = prepare_model_cli(
@@ -85,12 +84,17 @@ def patch_predictor(
8584
verbose=verbose,
8685
)
8786

87+
ioconfig = prepare_ioconfig(
88+
IOPatchPredictorConfig,
89+
pretrained_weights=weights,
90+
yaml_config_path=yaml_config_path,
91+
)
92+
8893
_ = predictor.run(
8994
images=files_all,
9095
masks=masks_all,
9196
patch_mode=patch_mode,
92-
resolution=resolution,
93-
units=units,
97+
ioconfig=ioconfig,
9498
device=device,
9599
save_dir=output_path,
96100
output_type=output_type,

tiatoolbox/cli/semantic_segment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
cli_pretrained_weights,
1818
cli_verbose,
1919
cli_yaml_config_path,
20-
prepare_ioconfig_seg,
20+
prepare_ioconfig,
2121
prepare_model_cli,
2222
tiatoolbox_cli,
2323
)
@@ -71,7 +71,7 @@ def semantic_segment(
7171
file_types=file_types,
7272
)
7373

74-
ioconfig = prepare_ioconfig_seg(
74+
ioconfig = prepare_ioconfig(
7575
IOSegmentorConfig,
7676
pretrained_weights,
7777
yaml_config_path,

tiatoolbox/models/architecture/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
from __future__ import annotations
44

55
import sys
6+
from typing import TYPE_CHECKING
67

78
import numpy as np
89
import torch
910
from torch import nn
1011

1112
from tiatoolbox import logger
1213

14+
if TYPE_CHECKING: # pragma: no cover
15+
from tiatoolbox.models.models_abc import ModelABC
16+
1317

1418
def is_torch_compile_compatible() -> bool:
1519
"""Check if the current GPU is compatible with torch-compile.
@@ -45,10 +49,10 @@ def is_torch_compile_compatible() -> bool:
4549

4650

4751
def compile_model(
48-
model: nn.Module | None = None,
52+
model: nn.Module | ModelABC | None = None,
4953
*,
5054
mode: str = "default",
51-
) -> nn.Module:
55+
) -> torch.nn.Module | ModelABC:
5256
"""A decorator to compile a model using torch-compile.
5357
5458
Args:
@@ -67,7 +71,7 @@ def compile_model(
6771
CUDA graphs
6872
6973
Returns:
70-
torch.nn.Module:
74+
torch.nn.Module or ModelABC:
7175
Compiled model.
7276
7377
"""

0 commit comments

Comments
 (0)