1- """Test for Patch Classifier ."""
1+ """Test for Patch Predictor ."""
22
33from __future__ import annotations
44
1414from click .testing import CliRunner
1515
1616from tiatoolbox import cli
17- from tiatoolbox .models import IOPatchClassifierConfig
17+ from tiatoolbox .models import IOPatchPredictorConfig
1818from tiatoolbox .models .architecture .vanilla import CNNModel
19- from tiatoolbox .models .engine .patch_classifier import PatchClassifier
19+ from tiatoolbox .models .engine .patch_predictor import PatchPredictor
2020from tiatoolbox .utils import env_detection as toolbox_env
2121from tiatoolbox .utils .misc import download_data , get_zarr_array , imwrite
2222
2626device = "cuda" if toolbox_env .has_gpu () else "cpu"
2727
2828
29- def _test_classifier_output (
29+ def _test_predictor_output (
3030 inputs : list ,
3131 model : str ,
3232 probabilities_check : list | None = None ,
@@ -37,13 +37,13 @@ def _test_classifier_output(
3737 """Test the predictions of multiple models included in tiatoolbox."""
3838 cache_mode = None if tmp_path is None else True
3939 save_dir = None if tmp_path is None else tmp_path / "output"
40- classifier = PatchClassifier (
40+ predictor = PatchPredictor (
4141 model = model ,
4242 batch_size = 32 ,
4343 verbose = False ,
4444 )
4545 # don't run test on GPU
46- output = classifier .run (
46+ output = predictor .run (
4747 inputs ,
4848 return_labels = False ,
4949 device = device ,
@@ -81,28 +81,28 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
8181 """Test for delegating args to io config."""
8282 mini_wsi_svs = Path (remote_sample ("wsi2_4k_4k_svs" ))
8383 model = CNNModel ("resnet50" )
84- classifier = PatchClassifier (model = model , weights = None )
84+ predictor = PatchPredictor (model = model , weights = None )
8585 kwargs = {
8686 "patch_input_shape" : [512 , 512 ],
8787 "resolution" : 1.75 ,
8888 "units" : "mpp" ,
8989 }
9090
9191 # test providing config / full input info for default models without weights
92- ioconfig = IOPatchClassifierConfig (
92+ ioconfig = IOPatchPredictorConfig (
9393 patch_input_shape = (512 , 512 ),
9494 stride_shape = (256 , 256 ),
9595 input_resolutions = [{"resolution" : 1.35 , "units" : "mpp" }],
9696 )
97- classifier .run (
97+ predictor .run (
9898 images = [mini_wsi_svs ],
9999 ioconfig = ioconfig ,
100100 patch_mode = False ,
101101 save_dir = f"{ tmp_path } /dump" ,
102102 )
103103 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
104104
105- classifier .run (
105+ predictor .run (
106106 images = [mini_wsi_svs ],
107107 patch_mode = False ,
108108 save_dir = f"{ tmp_path } /dump" ,
@@ -111,80 +111,80 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
111111 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
112112
113113 # test overwriting pretrained ioconfig
114- classifier = PatchClassifier (model = "resnet18-kather100k" , batch_size = 1 )
115- classifier .run (
114+ predictor = PatchPredictor (model = "resnet18-kather100k" , batch_size = 1 )
115+ predictor .run (
116116 images = [mini_wsi_svs ],
117117 patch_input_shape = (300 , 300 ),
118118 patch_mode = False ,
119119 save_dir = f"{ tmp_path } /dump" ,
120120 )
121- assert classifier ._ioconfig .patch_input_shape == (300 , 300 )
121+ assert predictor ._ioconfig .patch_input_shape == (300 , 300 )
122122 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
123123
124- classifier .run (
124+ predictor .run (
125125 images = [mini_wsi_svs ],
126126 stride_shape = (300 , 300 ),
127127 patch_mode = False ,
128128 save_dir = f"{ tmp_path } /dump" ,
129129 )
130- assert classifier ._ioconfig .stride_shape == (300 , 300 )
130+ assert predictor ._ioconfig .stride_shape == (300 , 300 )
131131 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
132132
133- classifier .run (
133+ predictor .run (
134134 images = [mini_wsi_svs ],
135135 resolution = 1.99 ,
136136 patch_mode = False ,
137137 save_dir = f"{ tmp_path } /dump" ,
138138 )
139- assert classifier ._ioconfig .input_resolutions [0 ]["resolution" ] == 1.99
139+ assert predictor ._ioconfig .input_resolutions [0 ]["resolution" ] == 1.99
140140 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
141141
142- classifier .run (
142+ predictor .run (
143143 images = [mini_wsi_svs ],
144144 units = "baseline" ,
145145 patch_mode = False ,
146146 save_dir = f"{ tmp_path } /dump" ,
147147 )
148- assert classifier ._ioconfig .input_resolutions [0 ]["units" ] == "baseline"
148+ assert predictor ._ioconfig .input_resolutions [0 ]["units" ] == "baseline"
149149 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
150150
151- classifier .run (
151+ predictor .run (
152152 images = [mini_wsi_svs ],
153153 units = "level" ,
154154 resolution = 0 ,
155155 patch_mode = False ,
156156 save_dir = f"{ tmp_path } /dump" ,
157157 )
158- assert classifier ._ioconfig .input_resolutions [0 ]["units" ] == "level"
159- assert classifier ._ioconfig .input_resolutions [0 ]["resolution" ] == 0
158+ assert predictor ._ioconfig .input_resolutions [0 ]["units" ] == "level"
159+ assert predictor ._ioconfig .input_resolutions [0 ]["resolution" ] == 0
160160 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
161161
162- classifier .run (
162+ predictor .run (
163163 images = [mini_wsi_svs ],
164164 units = "power" ,
165165 resolution = 20 ,
166166 patch_mode = False ,
167167 save_dir = f"{ tmp_path } /dump" ,
168168 )
169- assert classifier ._ioconfig .input_resolutions [0 ]["units" ] == "power"
170- assert classifier ._ioconfig .input_resolutions [0 ]["resolution" ] == 20
169+ assert predictor ._ioconfig .input_resolutions [0 ]["units" ] == "power"
170+ assert predictor ._ioconfig .input_resolutions [0 ]["resolution" ] == 20
171171 shutil .rmtree (tmp_path / "dump" , ignore_errors = True )
172172
173173
174- def test_patch_classifier_api (
174+ def test_patch_predictor_api (
175175 sample_patch1 : Path ,
176176 sample_patch2 : Path ,
177177 tmp_path : Path ,
178178) -> None :
179- """Test Patch Classifier API."""
179+ """Test PatchPredictor API."""
180180 save_dir_path = tmp_path
181181
182182 # convert to pathlib Path to prevent reader complaint
183183 inputs = [Path (sample_patch1 ), Path (sample_patch2 )]
184- classifier = PatchClassifier (model = "resnet18-kather100k" , batch_size = 1 )
184+ predictor = PatchPredictor (model = "resnet18-kather100k" , batch_size = 1 )
185185 # don't run test on GPU
186186 # Default run
187- output = classifier .run (
187+ output = predictor .run (
188188 inputs ,
189189 device = "cpu" ,
190190 )
@@ -193,7 +193,7 @@ def test_patch_classifier_api(
193193 shutil .rmtree (save_dir_path , ignore_errors = True )
194194
195195 # whether to return labels
196- output = classifier .run (
196+ output = predictor .run (
197197 inputs ,
198198 labels = ["1" , "a" ],
199199 return_labels = True ,
@@ -217,17 +217,17 @@ def test_patch_classifier_api(
217217
218218 download_data (pretrained_weights_url , pretrained_weights )
219219
220- classifier = PatchClassifier (
220+ predictor = PatchPredictor (
221221 model = "resnet18-kather100k" ,
222222 weights = pretrained_weights ,
223223 batch_size = 1 ,
224224 )
225- ioconfig = classifier .ioconfig
225+ ioconfig = predictor .ioconfig
226226
227227 # --- test different using user model
228228 model = CNNModel (backbone = "resnet18" , num_classes = 9 )
229229 # test prediction
230- predictor = PatchClassifier (model = model , batch_size = 1 , verbose = False )
230+ predictor = PatchPredictor (model = model , batch_size = 1 , verbose = False )
231231 output = predictor .run (
232232 inputs ,
233233 labels = [1 , 2 ],
@@ -239,7 +239,7 @@ def test_patch_classifier_api(
239239 assert output ["labels" ].tolist () == [1 , 2 ]
240240
241241
242- def test_wsi_classifier_api (
242+ def test_wsi_predictor_api (
243243 sample_wsi_dict : dict ,
244244 tmp_path : Path ,
245245) -> None :
@@ -252,7 +252,7 @@ def test_wsi_classifier_api(
252252 mini_wsi_msk = Path (sample_wsi_dict ["wsi2_4k_4k_msk" ])
253253
254254 patch_size = np .array ([224 , 224 ])
255- predictor = PatchClassifier (model = "resnet18-kather100k" , batch_size = 32 )
255+ predictor = PatchPredictor (model = "resnet18-kather100k" , batch_size = 32 )
256256
257257 save_dir = f"{ save_dir_path } /model_wsi_output"
258258
@@ -290,7 +290,7 @@ def test_wsi_classifier_api(
290290 shutil .rmtree (_kwargs ["save_dir" ], ignore_errors = True )
291291
292292
293- def test_patch_classifier_kather100k_output (
293+ def test_patch_predictor_kather100k_output (
294294 sample_patch1 : Path ,
295295 sample_patch2 : Path ,
296296 tmp_path : Path ,
@@ -317,7 +317,7 @@ def test_patch_classifier_kather100k_output(
317317 "googlenet-kather100k" : [1.0 , 0.9999639987945557 ],
318318 }
319319 for model , expected_prob in pretrained_info .items ():
320- _test_classifier_output (
320+ _test_predictor_output (
321321 inputs ,
322322 model ,
323323 probabilities_check = expected_prob ,
@@ -326,7 +326,7 @@ def test_patch_classifier_kather100k_output(
326326
327327 # cache mode
328328 for model , expected_prob in pretrained_info .items ():
329- _test_classifier_output (
329+ _test_predictor_output (
330330 inputs ,
331331 model ,
332332 probabilities_check = expected_prob ,
@@ -376,19 +376,19 @@ def _validate_probabilities(output: list | dict | zarr.group) -> bool:
376376 return np .all (predictions [:][0 :5 ] == [7 , 3 , 2 , 3 , 3 ])
377377
378378
379- def test_wsi_classifier_zarr (
379+ def test_wsi_predictor_zarr (
380380 sample_wsi_dict : dict , tmp_path : Path , caplog : pytest .LogCaptureFixture
381381) -> None :
382- """Test normal run of patch classifier for WSIs."""
382+ """Test normal run of patch predictor for WSIs."""
383383 mini_wsi_svs = Path (sample_wsi_dict ["wsi2_4k_4k_svs" ])
384384
385- classifier = PatchClassifier (
385+ predictor = PatchPredictor (
386386 model = "alexnet-kather100k" ,
387387 batch_size = 32 ,
388388 verbose = False ,
389389 )
390390 # don't run test on GPU
391- output = classifier .run (
391+ output = predictor .run (
392392 images = [mini_wsi_svs ],
393393 return_probabilities = True ,
394394 return_labels = False ,
@@ -412,7 +412,7 @@ def test_wsi_classifier_zarr(
412412 assert _validate_probabilities (output = output_ )
413413 assert "Output file saved at " in caplog .text
414414
415- output = classifier .run (
415+ output = predictor .run (
416416 images = [mini_wsi_svs ],
417417 return_probabilities = False ,
418418 return_labels = False ,
@@ -436,21 +436,21 @@ def test_wsi_classifier_zarr(
436436 assert "Output file saved at " in caplog .text
437437
438438
439- def test_patch_classifier_patch_mode_annotation_store (
439+ def test_patch_predictor_patch_mode_annotation_store (
440440 sample_patch1 : Path ,
441441 sample_patch2 : Path ,
442442 tmp_path : Path ,
443443) -> None :
444444 """Test the output of patch classification models on Kather100K dataset."""
445445 inputs = [Path (sample_patch1 ), Path (sample_patch2 )]
446446
447- classifier = PatchClassifier (
447+ predictor = PatchPredictor (
448448 model = "alexnet-kather100k" ,
449449 batch_size = 32 ,
450450 verbose = False ,
451451 )
452452 # don't run test on GPU
453- output = classifier .run (
453+ output = predictor .run (
454454 images = inputs ,
455455 return_probabilities = True ,
456456 return_labels = False ,
@@ -467,21 +467,21 @@ def test_patch_classifier_patch_mode_annotation_store(
467467 assert np .all (np .array (output ["probabilities" ]) >= 0 )
468468
469469
470- def test_patch_classifier_patch_mode_no_probabilities (
470+ def test_patch_predictor_patch_mode_no_probabilities (
471471 sample_patch1 : Path ,
472472 sample_patch2 : Path ,
473473 tmp_path : Path ,
474474) -> None :
475475 """Test the output of patch classification models on Kather100K dataset."""
476476 inputs = [Path (sample_patch1 ), Path (sample_patch2 )]
477477
478- classifier = PatchClassifier (
478+ predictor = PatchPredictor (
479479 model = "alexnet-kather100k" ,
480480 batch_size = 32 ,
481481 verbose = False ,
482482 )
483483
484- output = classifier .run (
484+ output = predictor .run (
485485 images = inputs ,
486486 return_probabilities = False ,
487487 return_labels = False ,
@@ -492,7 +492,7 @@ def test_patch_classifier_patch_mode_no_probabilities(
492492 assert "probabilities" not in output
493493
494494 # don't run test on GPU
495- output = classifier .run (
495+ output = predictor .run (
496496 images = inputs ,
497497 return_probabilities = False ,
498498 return_labels = False ,
@@ -518,7 +518,7 @@ def test_engine_run_wsi_annotation_store(
518518 mini_wsi_svs = Path (sample_wsi_dict ["wsi2_4k_4k_svs" ])
519519 mini_wsi_msk = Path (sample_wsi_dict ["wsi2_4k_4k_msk" ])
520520
521- eng = PatchClassifier (model = "alexnet-kather100k" )
521+ eng = PatchPredictor (model = "alexnet-kather100k" )
522522
523523 patch_size = np .array ([224 , 224 ])
524524 save_dir = f"{ tmp_path } /model_wsi_output"
@@ -567,7 +567,7 @@ def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None:
567567 models_wsi_result = runner .invoke (
568568 cli .main ,
569569 [
570- "patch-classifier " ,
570+ "patch-predictor " ,
571571 "--img-input" ,
572572 str (sample_svs ),
573573 "--patch-mode" ,
@@ -618,7 +618,7 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
618618 models_tiles_result = runner .invoke (
619619 cli .main ,
620620 [
621- "patch-classifier " ,
621+ "patch-predictor " ,
622622 "--img-input" ,
623623 str (dir_path ),
624624 "--patch-mode" ,
0 commit comments