1111import numpy as np
1212import pytest
1313import torchvision .models as torch_models
14- import zarr
1514from typing_extensions import Unpack
1615
1716from tiatoolbox .models .architecture import (
2625 prepare_engines_save_dir ,
2726)
2827from tiatoolbox .models .engine .io_config import ModelIOConfigABC
29- from tiatoolbox .utils .misc import write_to_zarr_in_cache_mode
3028
3129if TYPE_CHECKING :
3230 import torch .nn
@@ -62,19 +60,6 @@ def get_dataloader(
6260 patch_mode = patch_mode ,
6361 )
6462
65- def save_wsi_output (
66- self : EngineABC ,
67- processed_output : dict ,
68- save_dir : Path ,
69- ** kwargs : dict ,
70- ) -> Path :
71- """Test post_process_wsi."""
72- return super ().save_wsi_output (
73- processed_output ,
74- save_dir = save_dir ,
75- ** kwargs ,
76- )
77-
7863 def post_process_wsi (
7964 self : EngineABC ,
8065 raw_predictions : dict | Path ,
@@ -100,16 +85,6 @@ def infer_wsi(
10085 )
10186
10287
103- def test_engine_abc () -> NoReturn :
104- """Test EngineABC initialization."""
105- with pytest .raises (
106- TypeError ,
107- match = r".*Can't instantiate abstract class EngineABC*" ,
108- ):
109- # Can't instantiate abstract class with abstract methods
110- EngineABC () # skipcq
111-
112-
11388def test_engine_abc_incorrect_model_type () -> NoReturn :
11489 """Test EngineABC initialization with incorrect model type."""
11590 with pytest .raises (
@@ -295,7 +270,7 @@ def test_engine_initalization() -> NoReturn:
295270 assert isinstance (eng , EngineABC )
296271
297272
298- def test_engine_run (tmp_path : Path , sample_svs : Path ) -> NoReturn :
273+ def test_engine_run () -> NoReturn :
299274 """Test engine run."""
300275 eng = TestEngineABC (model = "alexnet-kather100k" )
301276 assert isinstance (eng , EngineABC )
@@ -372,14 +347,10 @@ def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn:
372347 assert "probabilities" in out
373348 assert "labels" in out
374349
375- eng = TestEngineABC (model = "alexnet-kather100k" )
376-
377- with pytest .raises (NotImplementedError ):
378- eng .run (
379- images = [sample_svs ],
380- save_dir = tmp_path / "output" ,
381- patch_mode = False ,
382- )
350+ pred = eng .post_process_wsi (
351+ raw_predictions = Path ("/path/to/raw_predictions.npy" ),
352+ )
353+ assert str (pred ) == "/path/to/raw_predictions.npy"
383354
384355
385356def test_engine_run_with_verbose () -> NoReturn :
@@ -542,55 +513,6 @@ def test_get_dataloader(sample_svs: Path) -> None:
542513 assert isinstance (dataloader .dataset , WSIPatchDataset )
543514
544515
545- def test_eng_save_output (tmp_path : pytest .TempPathFactory ) -> None :
546- """Test the eng.save_output() function."""
547- eng = TestEngineABC (model = "alexnet-kather100k" )
548- save_path = tmp_path / "output.zarr"
549- _ = zarr .open (save_path , mode = "w" )
550- out = eng .save_wsi_output (
551- processed_output = save_path ,
552- save_path = save_path ,
553- output_type = "zarr" ,
554- save_dir = tmp_path ,
555- )
556-
557- assert out .exists ()
558- assert out .suffix == ".zarr"
559-
560- # Test AnnotationStore
561- patch_output = {
562- "predictions" : np .array ([1 , 0 , 1 ]),
563- "coordinates" : np .array ([(0 , 0 , 1 , 1 ), (1 , 1 , 2 , 2 ), (2 , 2 , 3 , 3 )]),
564- }
565- class_dict = {0 : "class0" , 1 : "class1" }
566- save_path = tmp_path / "output_db.zarr"
567- zarr_group = zarr .open (save_path , mode = "w" )
568- _ = write_to_zarr_in_cache_mode (
569- zarr_group = zarr_group , output_data_to_save = patch_output
570- )
571- out = eng .save_wsi_output (
572- processed_output = save_path ,
573- scale_factor = (1.0 , 1.0 ),
574- class_dict = class_dict ,
575- save_dir = tmp_path ,
576- output_type = "AnnotationStore" ,
577- )
578-
579- assert out .exists ()
580- assert out .suffix == ".db"
581-
582- with pytest .raises (
583- ValueError ,
584- match = r".*supports zarr and AnnotationStore as output_type." ,
585- ):
586- eng .save_wsi_output (
587- processed_output = save_path ,
588- save_path = save_path ,
589- output_type = "dict" ,
590- save_dir = tmp_path ,
591- )
592-
593-
594516def test_io_config_delegation (tmp_path : Path , caplog : pytest .LogCaptureFixture ) -> None :
595517 """Test for delegating args to io config."""
596518 # test not providing config / full input info for not pretrained models
@@ -701,16 +623,3 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
701623 resolution = _kwargs ["resolution" ],
702624 units = _kwargs ["units" ],
703625 )
704-
705-
706- def test_notimplementederror_wsi_mode (
707- sample_svs : Path , tmp_path : pytest .TempPathFactory
708- ) -> None :
709- """Test that NotImplementedError is raised when wsi mode is False.
710-
711- A user should implement run method when patch_mode is False.
712-
713- """
714- eng = TestEngineABC (model = "alexnet-kather100k" )
715- with pytest .raises (NotImplementedError ):
716- eng .run (images = [sample_svs ], patch_mode = False , save_dir = tmp_path / "output" )
0 commit comments