1- """Test for Patch Predictor ."""
1+ """Test PatchPredictor ."""
22
33from __future__ import annotations
44
1111
1212import numpy as np
1313import torch
14+ import yaml
1415import zarr
1516from click .testing import CliRunner
1617
1718from tests .conftest import timed
1819from tiatoolbox import cli , logger , rcParam
1920from tiatoolbox .models import IOPatchPredictorConfig
21+ from tiatoolbox .models .architecture import fetch_pretrained_weights
2022from tiatoolbox .models .architecture .vanilla import CNNModel
2123from tiatoolbox .models .engine .patch_predictor import PatchPredictor
2224from tiatoolbox .utils import env_detection as toolbox_env
@@ -86,8 +88,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
8688 predictor = PatchPredictor (model = model , weights = None )
8789 kwargs = {
8890 "patch_input_shape" : [512 , 512 ],
89- "resolution" : 1.75 ,
90- "units" : "mpp" ,
91+ "input_resolutions" : [{"units" : "mpp" , "resolution" : 1.75 }],
9192 }
9293
9394 # test providing config / full input info for default models without weights
@@ -134,7 +135,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
134135
135136 predictor .run (
136137 images = [mini_wsi_svs ],
137- resolution = 1.99 ,
138+ input_resolutions = [{ "units" : "mpp" , "resolution" : 1.99 }] ,
138139 patch_mode = False ,
139140 save_dir = f"{ tmp_path } /dump" ,
140141 )
@@ -143,7 +144,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
143144
144145 predictor .run (
145146 images = [mini_wsi_svs ],
146- units = " baseline" ,
147+ input_resolutions = [{ "units" : " baseline", "resolution" : 1.0 }] ,
147148 patch_mode = False ,
148149 save_dir = f"{ tmp_path } /dump" ,
149150 )
@@ -152,8 +153,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
152153
153154 predictor .run (
154155 images = [mini_wsi_svs ],
155- units = "level" ,
156- resolution = 0 ,
156+ input_resolutions = [{"units" : "level" , "resolution" : 0 }],
157157 patch_mode = False ,
158158 save_dir = f"{ tmp_path } /dump" ,
159159 )
@@ -163,8 +163,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
163163
164164 predictor .run (
165165 images = [mini_wsi_svs ],
166- units = "power" ,
167- resolution = 20 ,
166+ input_resolutions = [{"units" : "power" , "resolution" : 20 }],
168167 patch_mode = False ,
169168 save_dir = f"{ tmp_path } /dump" ,
170169 )
@@ -262,8 +261,7 @@ def test_wsi_predictor_api(
262261 kwargs = {
263262 "patch_input_shape" : patch_size ,
264263 "stride_shape" : patch_size ,
265- "resolution" : 1.0 ,
266- "units" : "baseline" ,
264+ "input_resolutions" : [{"units" : "baseline" , "resolution" : 1.0 }],
267265 "save_dir" : save_dir ,
268266 }
269267 # ! add this test back once the read at `baseline` is fixed
@@ -646,6 +644,17 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
646644 dir_path_masks = tmp_path .joinpath ("new_copies_masks" )
647645 dir_path_masks .mkdir ()
648646
647+ config = {
648+ "input_resolutions" : [{"units" : "mpp" , "resolution" : 0.5 }],
649+ "patch_input_shape" : [224 , 224 ],
650+ }
651+
652+ with Path .open (tmp_path .joinpath ("config.yaml" ), "w" ) as fptr :
653+ yaml .dump (config , fptr )
654+
655+ model = "alexnet-kather100k"
656+ weights = fetch_pretrained_weights (model )
657+
649658 try :
650659 dir_path .joinpath ("1_" + mini_wsi_svs .name ).symlink_to (mini_wsi_svs )
651660 dir_path .joinpath ("2_" + mini_wsi_svs .name ).symlink_to (mini_wsi_svs )
@@ -675,6 +684,12 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
675684 str (False ),
676685 "--masks" ,
677686 str (dir_path_masks ),
687+ "--model" ,
688+ model ,
689+ "--weights" ,
690+ str (weights ),
691+ "--yaml-config-path" ,
692+ tmp_path / "config.yaml" ,
678693 "--output-path" ,
679694 str (tmp_path / "output" ),
680695 "--output-type" ,
0 commit comments