Skip to content

Commit 1fccf15

Browse files
shaneahmedmeastymostafajahanifaradamshephardJiaqi-Lv
authored
✨ Add PatchPredictor Engine (#865)
- Add `PatchPredictor` Engine based on `EngineABC` - Add `return_probabilities` option to Params - Removes `merge_predictions` option in `PatchPredictor` engine. - Defines `post_process_cache_mode` which allows running the algorithm on `WSI` - Add `infer_wsi` for WSI inference - Removes `save_wsi_output` as this is not required after post processing. - Removes `merge_predictions` and fixes docstring in EngineABCRunParams - `compile_model` is now moved to EngineABC init - Fixes bug with `_calculate_scale_factor` - Fixes a bug in `class_dict` definition. - `_get_zarr_array` is now a public function `get_zarr_array` in `misc` - `patch_predictions_as_annotations` runs the loop on `patch_coords` instead of `class_probs` --------- Co-authored-by: Mark Eastwood <[email protected]> Co-authored-by: Mostafa Jahanifar <[email protected]> Co-authored-by: Adam Shephard <[email protected]> Co-authored-by: Jiaqi-Lv <[email protected]>
1 parent dba269c commit 1fccf15

File tree

9 files changed

+511
-484
lines changed

9 files changed

+511
-484
lines changed

tests/engines/test_engine_abc.py

Lines changed: 5 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import numpy as np
1212
import pytest
1313
import torchvision.models as torch_models
14-
import zarr
1514
from typing_extensions import Unpack
1615

1716
from tiatoolbox.models.architecture import (
@@ -26,7 +25,6 @@
2625
prepare_engines_save_dir,
2726
)
2827
from tiatoolbox.models.engine.io_config import ModelIOConfigABC
29-
from tiatoolbox.utils.misc import write_to_zarr_in_cache_mode
3028

3129
if TYPE_CHECKING:
3230
import torch.nn
@@ -62,19 +60,6 @@ def get_dataloader(
6260
patch_mode=patch_mode,
6361
)
6462

65-
def save_wsi_output(
66-
self: EngineABC,
67-
processed_output: dict,
68-
save_dir: Path,
69-
**kwargs: dict,
70-
) -> Path:
71-
"""Test post_process_wsi."""
72-
return super().save_wsi_output(
73-
processed_output,
74-
save_dir=save_dir,
75-
**kwargs,
76-
)
77-
7863
def post_process_wsi(
7964
self: EngineABC,
8065
raw_predictions: dict | Path,
@@ -100,16 +85,6 @@ def infer_wsi(
10085
)
10186

10287

103-
def test_engine_abc() -> NoReturn:
104-
"""Test EngineABC initialization."""
105-
with pytest.raises(
106-
TypeError,
107-
match=r".*Can't instantiate abstract class EngineABC*",
108-
):
109-
# Can't instantiate abstract class with abstract methods
110-
EngineABC() # skipcq
111-
112-
11388
def test_engine_abc_incorrect_model_type() -> NoReturn:
11489
"""Test EngineABC initialization with incorrect model type."""
11590
with pytest.raises(
@@ -295,7 +270,7 @@ def test_engine_initalization() -> NoReturn:
295270
assert isinstance(eng, EngineABC)
296271

297272

298-
def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn:
273+
def test_engine_run() -> NoReturn:
299274
"""Test engine run."""
300275
eng = TestEngineABC(model="alexnet-kather100k")
301276
assert isinstance(eng, EngineABC)
@@ -372,14 +347,10 @@ def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn:
372347
assert "probabilities" in out
373348
assert "labels" in out
374349

375-
eng = TestEngineABC(model="alexnet-kather100k")
376-
377-
with pytest.raises(NotImplementedError):
378-
eng.run(
379-
images=[sample_svs],
380-
save_dir=tmp_path / "output",
381-
patch_mode=False,
382-
)
350+
pred = eng.post_process_wsi(
351+
raw_predictions=Path("/path/to/raw_predictions.npy"),
352+
)
353+
assert str(pred) == "/path/to/raw_predictions.npy"
383354

384355

385356
def test_engine_run_with_verbose() -> NoReturn:
@@ -542,55 +513,6 @@ def test_get_dataloader(sample_svs: Path) -> None:
542513
assert isinstance(dataloader.dataset, WSIPatchDataset)
543514

544515

545-
def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None:
546-
"""Test the eng.save_output() function."""
547-
eng = TestEngineABC(model="alexnet-kather100k")
548-
save_path = tmp_path / "output.zarr"
549-
_ = zarr.open(save_path, mode="w")
550-
out = eng.save_wsi_output(
551-
processed_output=save_path,
552-
save_path=save_path,
553-
output_type="zarr",
554-
save_dir=tmp_path,
555-
)
556-
557-
assert out.exists()
558-
assert out.suffix == ".zarr"
559-
560-
# Test AnnotationStore
561-
patch_output = {
562-
"predictions": np.array([1, 0, 1]),
563-
"coordinates": np.array([(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)]),
564-
}
565-
class_dict = {0: "class0", 1: "class1"}
566-
save_path = tmp_path / "output_db.zarr"
567-
zarr_group = zarr.open(save_path, mode="w")
568-
_ = write_to_zarr_in_cache_mode(
569-
zarr_group=zarr_group, output_data_to_save=patch_output
570-
)
571-
out = eng.save_wsi_output(
572-
processed_output=save_path,
573-
scale_factor=(1.0, 1.0),
574-
class_dict=class_dict,
575-
save_dir=tmp_path,
576-
output_type="AnnotationStore",
577-
)
578-
579-
assert out.exists()
580-
assert out.suffix == ".db"
581-
582-
with pytest.raises(
583-
ValueError,
584-
match=r".*supports zarr and AnnotationStore as output_type.",
585-
):
586-
eng.save_wsi_output(
587-
processed_output=save_path,
588-
save_path=save_path,
589-
output_type="dict",
590-
save_dir=tmp_path,
591-
)
592-
593-
594516
def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
595517
"""Test for delegating args to io config."""
596518
# test not providing config / full input info for not pretrained models
@@ -701,16 +623,3 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
701623
resolution=_kwargs["resolution"],
702624
units=_kwargs["units"],
703625
)
704-
705-
706-
def test_notimplementederror_wsi_mode(
707-
sample_svs: Path, tmp_path: pytest.TempPathFactory
708-
) -> None:
709-
"""Test that NotImplementedError is raised when wsi mode is False.
710-
711-
A user should implement run method when patch_mode is False.
712-
713-
"""
714-
eng = TestEngineABC(model="alexnet-kather100k")
715-
with pytest.raises(NotImplementedError):
716-
eng.run(images=[sample_svs], patch_mode=False, save_dir=tmp_path / "output")

0 commit comments

Comments
 (0)