Skip to content

Commit b1bff40

Browse files
committed
♻️ Refactor all code and remove patch classifier. Replace it with patch predictor.
1 parent 2197ca0 commit b1bff40

File tree

8 files changed

+373
-670
lines changed

8 files changed

+373
-670
lines changed

tests/engines/test_patch_predictor.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Test for Patch Classifier."""
1+
"""Test for Patch Predictor."""
22

33
from __future__ import annotations
44

@@ -14,9 +14,9 @@
1414
from click.testing import CliRunner
1515

1616
from tiatoolbox import cli
17-
from tiatoolbox.models import IOPatchClassifierConfig
17+
from tiatoolbox.models import IOPatchPredictorConfig
1818
from tiatoolbox.models.architecture.vanilla import CNNModel
19-
from tiatoolbox.models.engine.patch_classifier import PatchClassifier
19+
from tiatoolbox.models.engine.patch_predictor import PatchPredictor
2020
from tiatoolbox.utils import env_detection as toolbox_env
2121
from tiatoolbox.utils.misc import download_data, get_zarr_array, imwrite
2222

@@ -26,7 +26,7 @@
2626
device = "cuda" if toolbox_env.has_gpu() else "cpu"
2727

2828

29-
def _test_classifier_output(
29+
def _test_predictor_output(
3030
inputs: list,
3131
model: str,
3232
probabilities_check: list | None = None,
@@ -37,13 +37,13 @@ def _test_classifier_output(
3737
"""Test the predictions of multiple models included in tiatoolbox."""
3838
cache_mode = None if tmp_path is None else True
3939
save_dir = None if tmp_path is None else tmp_path / "output"
40-
classifier = PatchClassifier(
40+
predictor = PatchPredictor(
4141
model=model,
4242
batch_size=32,
4343
verbose=False,
4444
)
4545
# don't run test on GPU
46-
output = classifier.run(
46+
output = predictor.run(
4747
inputs,
4848
return_labels=False,
4949
device=device,
@@ -81,28 +81,28 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
8181
"""Test for delegating args to io config."""
8282
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))
8383
model = CNNModel("resnet50")
84-
classifier = PatchClassifier(model=model, weights=None)
84+
predictor = PatchPredictor(model=model, weights=None)
8585
kwargs = {
8686
"patch_input_shape": [512, 512],
8787
"resolution": 1.75,
8888
"units": "mpp",
8989
}
9090

9191
# test providing config / full input info for default models without weights
92-
ioconfig = IOPatchClassifierConfig(
92+
ioconfig = IOPatchPredictorConfig(
9393
patch_input_shape=(512, 512),
9494
stride_shape=(256, 256),
9595
input_resolutions=[{"resolution": 1.35, "units": "mpp"}],
9696
)
97-
classifier.run(
97+
predictor.run(
9898
images=[mini_wsi_svs],
9999
ioconfig=ioconfig,
100100
patch_mode=False,
101101
save_dir=f"{tmp_path}/dump",
102102
)
103103
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
104104

105-
classifier.run(
105+
predictor.run(
106106
images=[mini_wsi_svs],
107107
patch_mode=False,
108108
save_dir=f"{tmp_path}/dump",
@@ -111,80 +111,80 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
111111
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
112112

113113
# test overwriting pretrained ioconfig
114-
classifier = PatchClassifier(model="resnet18-kather100k", batch_size=1)
115-
classifier.run(
114+
predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1)
115+
predictor.run(
116116
images=[mini_wsi_svs],
117117
patch_input_shape=(300, 300),
118118
patch_mode=False,
119119
save_dir=f"{tmp_path}/dump",
120120
)
121-
assert classifier._ioconfig.patch_input_shape == (300, 300)
121+
assert predictor._ioconfig.patch_input_shape == (300, 300)
122122
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
123123

124-
classifier.run(
124+
predictor.run(
125125
images=[mini_wsi_svs],
126126
stride_shape=(300, 300),
127127
patch_mode=False,
128128
save_dir=f"{tmp_path}/dump",
129129
)
130-
assert classifier._ioconfig.stride_shape == (300, 300)
130+
assert predictor._ioconfig.stride_shape == (300, 300)
131131
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
132132

133-
classifier.run(
133+
predictor.run(
134134
images=[mini_wsi_svs],
135135
resolution=1.99,
136136
patch_mode=False,
137137
save_dir=f"{tmp_path}/dump",
138138
)
139-
assert classifier._ioconfig.input_resolutions[0]["resolution"] == 1.99
139+
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99
140140
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
141141

142-
classifier.run(
142+
predictor.run(
143143
images=[mini_wsi_svs],
144144
units="baseline",
145145
patch_mode=False,
146146
save_dir=f"{tmp_path}/dump",
147147
)
148-
assert classifier._ioconfig.input_resolutions[0]["units"] == "baseline"
148+
assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline"
149149
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
150150

151-
classifier.run(
151+
predictor.run(
152152
images=[mini_wsi_svs],
153153
units="level",
154154
resolution=0,
155155
patch_mode=False,
156156
save_dir=f"{tmp_path}/dump",
157157
)
158-
assert classifier._ioconfig.input_resolutions[0]["units"] == "level"
159-
assert classifier._ioconfig.input_resolutions[0]["resolution"] == 0
158+
assert predictor._ioconfig.input_resolutions[0]["units"] == "level"
159+
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0
160160
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
161161

162-
classifier.run(
162+
predictor.run(
163163
images=[mini_wsi_svs],
164164
units="power",
165165
resolution=20,
166166
patch_mode=False,
167167
save_dir=f"{tmp_path}/dump",
168168
)
169-
assert classifier._ioconfig.input_resolutions[0]["units"] == "power"
170-
assert classifier._ioconfig.input_resolutions[0]["resolution"] == 20
169+
assert predictor._ioconfig.input_resolutions[0]["units"] == "power"
170+
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 20
171171
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
172172

173173

174-
def test_patch_classifier_api(
174+
def test_patch_predictor_api(
175175
sample_patch1: Path,
176176
sample_patch2: Path,
177177
tmp_path: Path,
178178
) -> None:
179-
"""Test Patch Classifier API."""
179+
"""Test PatchPredictor API."""
180180
save_dir_path = tmp_path
181181

182182
# convert to pathlib Path to prevent reader complaint
183183
inputs = [Path(sample_patch1), Path(sample_patch2)]
184-
classifier = PatchClassifier(model="resnet18-kather100k", batch_size=1)
184+
predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1)
185185
# don't run test on GPU
186186
# Default run
187-
output = classifier.run(
187+
output = predictor.run(
188188
inputs,
189189
device="cpu",
190190
)
@@ -193,7 +193,7 @@ def test_patch_classifier_api(
193193
shutil.rmtree(save_dir_path, ignore_errors=True)
194194

195195
# whether to return labels
196-
output = classifier.run(
196+
output = predictor.run(
197197
inputs,
198198
labels=["1", "a"],
199199
return_labels=True,
@@ -217,17 +217,17 @@ def test_patch_classifier_api(
217217

218218
download_data(pretrained_weights_url, pretrained_weights)
219219

220-
classifier = PatchClassifier(
220+
predictor = PatchPredictor(
221221
model="resnet18-kather100k",
222222
weights=pretrained_weights,
223223
batch_size=1,
224224
)
225-
ioconfig = classifier.ioconfig
225+
ioconfig = predictor.ioconfig
226226

227227
# --- test different using user model
228228
model = CNNModel(backbone="resnet18", num_classes=9)
229229
# test prediction
230-
predictor = PatchClassifier(model=model, batch_size=1, verbose=False)
230+
predictor = PatchPredictor(model=model, batch_size=1, verbose=False)
231231
output = predictor.run(
232232
inputs,
233233
labels=[1, 2],
@@ -239,7 +239,7 @@ def test_patch_classifier_api(
239239
assert output["labels"].tolist() == [1, 2]
240240

241241

242-
def test_wsi_classifier_api(
242+
def test_wsi_predictor_api(
243243
sample_wsi_dict: dict,
244244
tmp_path: Path,
245245
) -> None:
@@ -252,7 +252,7 @@ def test_wsi_classifier_api(
252252
mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"])
253253

254254
patch_size = np.array([224, 224])
255-
predictor = PatchClassifier(model="resnet18-kather100k", batch_size=32)
255+
predictor = PatchPredictor(model="resnet18-kather100k", batch_size=32)
256256

257257
save_dir = f"{save_dir_path}/model_wsi_output"
258258

@@ -290,7 +290,7 @@ def test_wsi_classifier_api(
290290
shutil.rmtree(_kwargs["save_dir"], ignore_errors=True)
291291

292292

293-
def test_patch_classifier_kather100k_output(
293+
def test_patch_predictor_kather100k_output(
294294
sample_patch1: Path,
295295
sample_patch2: Path,
296296
tmp_path: Path,
@@ -317,7 +317,7 @@ def test_patch_classifier_kather100k_output(
317317
"googlenet-kather100k": [1.0, 0.9999639987945557],
318318
}
319319
for model, expected_prob in pretrained_info.items():
320-
_test_classifier_output(
320+
_test_predictor_output(
321321
inputs,
322322
model,
323323
probabilities_check=expected_prob,
@@ -326,7 +326,7 @@ def test_patch_classifier_kather100k_output(
326326

327327
# cache mode
328328
for model, expected_prob in pretrained_info.items():
329-
_test_classifier_output(
329+
_test_predictor_output(
330330
inputs,
331331
model,
332332
probabilities_check=expected_prob,
@@ -376,19 +376,19 @@ def _validate_probabilities(output: list | dict | zarr.group) -> bool:
376376
return np.all(predictions[:][0:5] == [7, 3, 2, 3, 3])
377377

378378

379-
def test_wsi_classifier_zarr(
379+
def test_wsi_predictor_zarr(
380380
sample_wsi_dict: dict, tmp_path: Path, caplog: pytest.LogCaptureFixture
381381
) -> None:
382-
"""Test normal run of patch classifier for WSIs."""
382+
"""Test normal run of patch predictor for WSIs."""
383383
mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
384384

385-
classifier = PatchClassifier(
385+
predictor = PatchPredictor(
386386
model="alexnet-kather100k",
387387
batch_size=32,
388388
verbose=False,
389389
)
390390
# don't run test on GPU
391-
output = classifier.run(
391+
output = predictor.run(
392392
images=[mini_wsi_svs],
393393
return_probabilities=True,
394394
return_labels=False,
@@ -412,7 +412,7 @@ def test_wsi_classifier_zarr(
412412
assert _validate_probabilities(output=output_)
413413
assert "Output file saved at " in caplog.text
414414

415-
output = classifier.run(
415+
output = predictor.run(
416416
images=[mini_wsi_svs],
417417
return_probabilities=False,
418418
return_labels=False,
@@ -436,21 +436,21 @@ def test_wsi_classifier_zarr(
436436
assert "Output file saved at " in caplog.text
437437

438438

439-
def test_patch_classifier_patch_mode_annotation_store(
439+
def test_patch_predictor_patch_mode_annotation_store(
440440
sample_patch1: Path,
441441
sample_patch2: Path,
442442
tmp_path: Path,
443443
) -> None:
444444
"""Test the output of patch classification models on Kather100K dataset."""
445445
inputs = [Path(sample_patch1), Path(sample_patch2)]
446446

447-
classifier = PatchClassifier(
447+
predictor = PatchPredictor(
448448
model="alexnet-kather100k",
449449
batch_size=32,
450450
verbose=False,
451451
)
452452
# don't run test on GPU
453-
output = classifier.run(
453+
output = predictor.run(
454454
images=inputs,
455455
return_probabilities=True,
456456
return_labels=False,
@@ -467,21 +467,21 @@ def test_patch_classifier_patch_mode_annotation_store(
467467
assert np.all(np.array(output["probabilities"]) >= 0)
468468

469469

470-
def test_patch_classifier_patch_mode_no_probabilities(
470+
def test_patch_predictor_patch_mode_no_probabilities(
471471
sample_patch1: Path,
472472
sample_patch2: Path,
473473
tmp_path: Path,
474474
) -> None:
475475
"""Test the output of patch classification models on Kather100K dataset."""
476476
inputs = [Path(sample_patch1), Path(sample_patch2)]
477477

478-
classifier = PatchClassifier(
478+
predictor = PatchPredictor(
479479
model="alexnet-kather100k",
480480
batch_size=32,
481481
verbose=False,
482482
)
483483

484-
output = classifier.run(
484+
output = predictor.run(
485485
images=inputs,
486486
return_probabilities=False,
487487
return_labels=False,
@@ -492,7 +492,7 @@ def test_patch_classifier_patch_mode_no_probabilities(
492492
assert "probabilities" not in output
493493

494494
# don't run test on GPU
495-
output = classifier.run(
495+
output = predictor.run(
496496
images=inputs,
497497
return_probabilities=False,
498498
return_labels=False,
@@ -518,7 +518,7 @@ def test_engine_run_wsi_annotation_store(
518518
mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
519519
mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"])
520520

521-
eng = PatchClassifier(model="alexnet-kather100k")
521+
eng = PatchPredictor(model="alexnet-kather100k")
522522

523523
patch_size = np.array([224, 224])
524524
save_dir = f"{tmp_path}/model_wsi_output"
@@ -567,7 +567,7 @@ def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None:
567567
models_wsi_result = runner.invoke(
568568
cli.main,
569569
[
570-
"patch-classifier",
570+
"patch-predictor",
571571
"--img-input",
572572
str(sample_svs),
573573
"--patch-mode",
@@ -618,7 +618,7 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
618618
models_tiles_result = runner.invoke(
619619
cli.main,
620620
[
621-
"patch-classifier",
621+
"patch-predictor",
622622
"--img-input",
623623
str(dir_path),
624624
"--patch-mode",

tiatoolbox/cli/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tiatoolbox import __version__
99
from tiatoolbox.cli.common import tiatoolbox_cli
1010
from tiatoolbox.cli.nucleus_instance_segment import nucleus_instance_segment
11-
from tiatoolbox.cli.patch_classifier import patch_classifier
11+
from tiatoolbox.cli.patch_predictor import patch_predictor
1212
from tiatoolbox.cli.read_bounds import read_bounds
1313
from tiatoolbox.cli.save_tiles import save_tiles
1414
from tiatoolbox.cli.semantic_segment import semantic_segment
@@ -39,7 +39,7 @@ def main() -> click.BaseCommand:
3939

4040

4141
main.add_command(nucleus_instance_segment)
42-
main.add_command(patch_classifier)
42+
main.add_command(patch_predictor)
4343
main.add_command(read_bounds)
4444
main.add_command(save_tiles)
4545
main.add_command(semantic_segment)

0 commit comments

Comments
 (0)