diff --git a/tests/conftest.py b/tests/conftest.py index 1f1f0e6ae..4d1a01f46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import pathlib import shutil from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Callable import pytest @@ -59,6 +60,28 @@ def __remote_sample(key: str) -> pathlib.Path: return __remote_sample +@pytest.fixture(scope="session") +def blank_sample(tmp_path_factory: TempPathFactory): + """Factory fixture for creating blank sample files.""" + + class BlankSample: + """Sample file. Automatically deleted after use.""" + + def __init__(self, suffix: str): + self.suffix = suffix + self.file = None # will be set in __enter__ + + def __enter__(self) -> pathlib.Path: + folder = tmp_path_factory.mktemp("data") + self.file = NamedTemporaryFile(suffix=self.suffix, dir=folder, delete=True) + return pathlib.Path(self.file.name) + + def __exit__(self, exc_type, exc_value, traceback): + self.file.close() + + return BlankSample + + @pytest.fixture(scope="session") def sample_ndpi(remote_sample) -> pathlib.Path: """Sample pytest fixture for ndpi images. diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index f79e8a031..87f11bdec 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -25,7 +25,7 @@ ) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils.misc import download_data, imread, imwrite -from tiatoolbox.wsicore.wsireader import WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader ON_GPU = toolbox_env.has_gpu() @@ -224,9 +224,9 @@ def test_wsi_patch_dataset(sample_wsi_dict, tmp_path): mini_wsi_jpg = pathlib.Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) mini_wsi_msk = pathlib.Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - def reuse_init(img_path=mini_wsi_svs, **kwargs): + def reuse_init(input_img=mini_wsi_svs, **kwargs): """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) + return WSIPatchDataset(input_img=input_img, **kwargs) def reuse_init_wsi(**kwargs): """Testing function.""" @@ -251,9 +251,9 @@ def __getitem__(self, idx): Proto() # skipcq # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): + with pytest.raises(ValueError, match=r".*`input_img` path must exist.*"): WSIPatchDataset( - img_path="aaaa", + input_img="aaaa", mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -261,10 +261,23 @@ def __getitem__(self, idx): ) # invalid mask path input - with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): + with pytest.raises(ValueError, match=r".*`mask` must be a valid file path.*"): WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path="aaaa", + input_img=mini_wsi_svs, + mask="aaaa", + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + resolution=1.0, + units="mpp", + auto_get_mask=False, + ) + + # mask as not VirtualWSIReader + with pytest.raises(ValueError, match=r".*`mask` must be .* VirtualWSIReader.*"): + WSIPatchDataset( + input_img=mini_wsi_svs, + mask=WSIReader.open(mini_wsi_svs), mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -277,6 +290,10 @@ def __getitem__(self, idx): with pytest.raises(ValueError, match="`X` is not supported."): reuse_init(mode="X") + # invalid units + with pytest.raises(ValueError, match="`X` is not supported."): + reuse_init(units="X") + # invalid patch with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): reuse_init() @@ -348,8 +365,8 @@ def __getitem__(self, idx): ) assert len(ds) > 0 ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=mini_wsi_msk, + input_img=mini_wsi_svs, + mask=mini_wsi_msk, mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -363,8 +380,8 @@ def __getitem__(self, idx): imwrite(negative_mask_path, negative_mask) with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=negative_mask_path, + input_img=mini_wsi_svs, + mask=negative_mask_path, mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], @@ -376,7 +393,7 @@ def __getitem__(self, idx): # * for tile reader = WSIReader.open(mini_wsi_jpg) tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, + input_img=mini_wsi_jpg, mode="tile", patch_input_shape=patch_size, stride_shape=stride_size, @@ -397,6 +414,72 @@ def __getitem__(self, idx): assert roi1.shape[1] == roi2.shape[1] assert np.min(correlation) > 0.9, correlation + positive_mask = (negative_mask + 1).astype(bool) + # check mask as np array + with pytest.raises(ValueError, match=r".*`mask` must be binary.*"): + WSIPatchDataset( + input_img=mini_wsi_svs, + mask=np.array([[0, 0, 1]]), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + ds = WSIPatchDataset( + input_img=mini_wsi_svs, + mask=positive_mask, + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + assert len(ds) > 0 + + # check mask VirtualWSIReader + with pytest.raises(ValueError, match=r".*`mask` must be binary.*"): + WSIPatchDataset( + input_img=mini_wsi_svs, + mask=VirtualWSIReader(np.array([[0, 0, 5]])), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + ds_from_fp = WSIPatchDataset( + input_img=mini_wsi_svs, + mask=VirtualWSIReader(positive_mask, mode="bool"), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="mpp", + ) + + assert len(ds_from_fp) > 0 + + mini_wsi_svs_np = imread(mini_wsi_svs) + ds_from_np = WSIPatchDataset( + input_img=mini_wsi_svs_np, + mask=VirtualWSIReader(positive_mask, mode="bool"), + mode="wsi", + patch_input_shape=[512, 512], + stride_shape=[256, 256], + auto_get_mask=False, + resolution=1.0, + units="baseline", + ) + + assert len(ds_from_np) > 0 + def test_patch_dataset_abc(): """Test for ABC methods. @@ -493,6 +576,12 @@ def test_predictor_crash(): predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi") with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): predictor.predict([1, 2, 3], labels=[1, 2], mode="patch") + # mask on patch are not supported + with pytest.raises(ValueError, match=r".*masks are not supported .* `patch`.*"): + predictor.predict( + [np.array([1, 2, 3])], masks=[np.array([1, 2, 3])], mode="patch" + ) + # remove previously generated data _rm_dir("output") @@ -677,7 +766,7 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path): # test prediction predictor = PatchPredictor(model=model, batch_size=1, verbose=False) output = predictor.predict( - inputs, + imgs=inputs, return_probabilities=True, labels=[1, "a"], return_labels=True, @@ -803,6 +892,58 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path): # remove previously generated data _rm_dir("output") + # check that predictor can take in WSIReader object + svs_objects = [WSIReader.open(i) for i in [mini_wsi_svs, mini_wsi_svs]] + output = predictor.predict( + svs_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **kwargs, + ) + assert str(mini_wsi_svs) in output + # remove previously generated data + _rm_dir(kwargs["save_dir"]) + + # check that predictor can take in ndarray object + img_objects = [ + WSIReader.open(i).slide_thumbnail(1, "baseline") + for i in [mini_wsi_svs, mini_wsi_svs] + ] + + with pytest.raises(ValueError, match=".*Cannot determine scale.*"): + predictor.predict( + img_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **kwargs, + ) + _rm_dir(kwargs["save_dir"]) + + _ = predictor.predict( + img_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="tile", + ignore_resolutions=True, + **kwargs, + ) + _rm_dir(kwargs["save_dir"]) + + _kwargs = copy.deepcopy(kwargs) + _kwargs["units"] = "baseline" + _kwargs["resolution"] = 1.0 + + output = predictor.predict( + img_objects, + masks=[mini_wsi_msk, mini_wsi_msk], + mode="wsi", + **_kwargs, + ) + + assert len(output) == 2 + assert 0 in output + assert 1 in output + _rm_dir(_kwargs["save_dir"]) + def test_wsi_predictor_merge_predictions(sample_wsi_dict): """Test normal run of wsi predictor with merge predictions option.""" diff --git a/tests/test_patch_extraction.py b/tests/test_patch_extraction.py index 5a6faaad1..61bbf47af 100644 --- a/tests/test_patch_extraction.py +++ b/tests/test_patch_extraction.py @@ -98,10 +98,9 @@ def test_get_patch_extractor(source_image, patch_extr_csv): def test_points_patch_extractor_image_format( - sample_svs, sample_jp2, source_image, patch_extr_csv + sample_svs, sample_jp2, source_image, patch_extr_csv, blank_sample ): """Test PointsPatchExtractor returns the right object.""" - file_parent_dir = pathlib.Path(__file__).parent locations_list = pathlib.Path(patch_extr_csv) points = patchextraction.get_patch_extractor( @@ -131,10 +130,9 @@ def test_points_patch_extractor_image_format( assert isinstance(points.wsi, OmnyxJP2WSIReader) - false_image = pathlib.Path(file_parent_dir.joinpath("data/source_image.test")) - with pytest.raises(FileNotSupported): + with blank_sample(".test") as false_image_path, pytest.raises(FileNotSupported): _ = patchextraction.get_patch_extractor( - input_img=false_image, + input_img=false_image_path, locations_list=locations_list, method_name="point", patch_size=(200, 200), diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 0a4ec0443..eabf99729 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -1386,11 +1386,11 @@ def test_invalid_masker_method(sample_svs): def test_wsireader_open( - sample_svs, sample_ndpi, sample_jp2, sample_ome_tiff, source_image + sample_svs, sample_ndpi, sample_jp2, sample_ome_tiff, source_image, blank_sample ): """Test WSIReader.open() to return correct object.""" - with pytest.raises(FileNotSupported): - _ = WSIReader.open("./sample.csv") + with blank_sample(".csv") as path, pytest.raises(FileNotSupported): + _ = WSIReader.open(path) with pytest.raises(TypeError): _ = WSIReader.open([1, 2]) @@ -1656,13 +1656,15 @@ def test_command_line_read_bounds(sample_ndpi, tmp_path): def test_command_line_jp2_read_bounds(sample_jp2, tmp_path): """Test JP2 read_bounds.""" + input_img = pathlib.Path(sample_jp2) + runner = CliRunner() read_bounds_result = runner.invoke( cli.main, [ "read-bounds", "--img-input", - str(pathlib.Path(sample_jp2)), + str(input_img), "--resolution", "0", "--units", @@ -1673,7 +1675,9 @@ def test_command_line_jp2_read_bounds(sample_jp2, tmp_path): ) assert read_bounds_result.exit_code == 0 - assert pathlib.Path(tmp_path).joinpath("../im_region.jpg").is_file() + input_dir = pathlib.Path(input_img).parent.parent + output_path = os.path.join(input_dir, "im_region.jpg") + assert pathlib.Path(output_path).is_file() @pytest.mark.skipif( @@ -1701,23 +1705,25 @@ def test_command_line_jp2_read_bounds_show(sample_jp2, tmp_path): assert read_bounds_result.exit_code == 0 -def test_command_line_unsupported_file_read_bounds(sample_svs, tmp_path): +def test_command_line_unsupported_file_read_bounds(sample_svs, tmp_path, blank_sample): """Test unsupported file read bounds.""" runner = CliRunner() - read_bounds_result = runner.invoke( - cli.main, - [ - "read-bounds", - "--img-input", - str(pathlib.Path(sample_svs))[:-1], - "--resolution", - "0", - "--units", - "level", - "--mode", - "save", - ], - ) + + with blank_sample(".csv") as file: + read_bounds_result = runner.invoke( + cli.main, + [ + "read-bounds", + "--img-input", + str(file), + "--resolution", + "0", + "--units", + "level", + "--mode", + "save", + ], + ) assert read_bounds_result.output == "" assert read_bounds_result.exit_code == 1 @@ -1985,7 +1991,9 @@ def test_store_reader_alpha(remote_sample): def test_store_reader_no_types(tmp_path, remote_sample): - """Test AnnotationStoreReader with no types.""" + """ + Test AnnotationStoreReader with no types. + """ SQLiteStore(tmp_path / "store.db") wsi_reader = WSIReader.open(remote_sample("svs-1-small")) reader = AnnotationStoreReader(tmp_path / "store.db", wsi_reader.info) diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index a0adc942b..72a6c43ea 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -1,16 +1,15 @@ import os import pathlib +from typing import Literal, Tuple, Union import cv2 import numpy as np import PIL import torchvision.transforms as transforms -from tiatoolbox import logger from tiatoolbox.models.dataset import dataset_abc from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.misc import imread -from tiatoolbox.wsicore.wsimeta import WSIMeta from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader @@ -115,9 +114,9 @@ def __getitem__(self, idx): data = { "image": patch, } + if self.labels is not None: data["label"] = self.labels[idx] - return data return data @@ -149,13 +148,13 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC): def __init__( self, - img_path, - mode="wsi", - mask_path=None, - patch_input_shape=None, - stride_shape=None, - resolution=None, - units=None, + input_img: Union[str, pathlib.Path, np.ndarray, WSIReader], + mode: Literal["wsi", "tile"] = "wsi", + mask: Union[str, pathlib.Path, np.ndarray, VirtualWSIReader] = None, + patch_input_shape: Union[Tuple[int, int], np.ndarray] = None, + stride_shape: Union[Tuple[int, int], np.ndarray] = None, + resolution: float = 1, + units: str = "baseline", auto_get_mask=True, min_mask_ratio=0, preproc_func=None, @@ -167,10 +166,16 @@ def __init__( Can be either `wsi` or `tile` to denote the image to read is either a whole-slide image or a large image tile. - img_path (:obj:`str` or :obj:`pathlib.Path`): + input_img: (:obj:`str` or + :obj:`pathlib.Path` or + :obj:`ndarray` or + :obj:`WSIReader`): Valid to pyramidal whole-slide image or large tile to read. - mask_path (:obj:`str` or :obj:`pathlib.Path`): + mask (:obj:`str` or + :obj:`pathlib.Path` or + :obj:`ndarray` or + :obj:`VirtualWSIReader`): Valid mask image. patch_input_shape: A tuple (int, int) or ndarray of shape (2,). Expected @@ -184,10 +189,10 @@ def __init__( `units`. Expected to be positive and of (height, width). Note, this is not at level 0. resolution: - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. + Check (:class:`.WSIReader`) for details. + If reading from an image without specified metadata, + use `resolution=1.0` and`units='baseline'` units: + check (:class:`.WSIReader`) for details. units: Units in which `resolution` is defined. auto_get_mask: @@ -219,11 +224,12 @@ def __init__( """ super().__init__() - # Is there a generic func for path test in toolbox? - if not os.path.isfile(img_path): - raise ValueError("`img_path` must be a valid file path.") if mode not in ["wsi", "tile"]: raise ValueError(f"`{mode}` is not supported.") + + if units not in ["baseline", "power", "mpp"]: + raise ValueError(f"`{units}` is not supported.") + patch_input_shape = np.array(patch_input_shape) stride_shape = np.array(stride_shape) @@ -241,38 +247,7 @@ def __init__( raise ValueError(f"Invalid `stride_shape` value {stride_shape}.") self.preproc_func = preproc_func - img_path = pathlib.Path(img_path) - if mode == "wsi": - self.reader = WSIReader.open(img_path) - else: - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - units = "baseline" - resolution = 1.0 - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # hack value such that read if mask is provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm - units = "mpp" - resolution = 1.0 - self.reader = VirtualWSIReader( - img, - info=metadata, - ) + self.reader = WSIReader.open(input_img) # may decouple into misc ? # the scaling factor will scale base level to requested read resolution/units @@ -286,17 +261,69 @@ def __init__( input_within_bound=False, ) + self._apply_mask( + mask=mask, + auto_get_mask=auto_get_mask, + wsi_shape=wsi_shape, + min_mask_ratio=min_mask_ratio, + mode=mode, + ) + + if len(self.inputs) == 0: + raise ValueError("No patch coordinates remain after filtering.") + + self.patch_input_shape = patch_input_shape + self.resolution = resolution + self.units = units + + # Perform check on the input + self._check_input_integrity(mode="wsi") + + def _apply_mask( + self, + mask: Union[str, pathlib.Path, np.ndarray, VirtualWSIReader], + auto_get_mask: bool = True, + wsi_shape: Tuple[int, int] = None, + min_mask_ratio: float = 0, + mode: Literal["wsi", "tile"] = "wsi", + ): + """Reads or generates a mask for the input image and + applies it to the dataset.""" mask_reader = None - if mask_path is not None: - if not os.path.isfile(mask_path): - raise ValueError("`mask_path` must be a valid file path.") - mask = imread(mask_path) # assume to be gray - mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) - mask = np.array(mask > 0, dtype=np.uint8) - - mask_reader = VirtualWSIReader(mask) + if mask is not None: + if not isinstance(mask, (str, pathlib.Path, np.ndarray, VirtualWSIReader)): + raise ValueError( + "`mask` must be file path, np.ndarray or VirtualWSIReader." + ) + + if isinstance(mask, VirtualWSIReader): + if mask.mode != "bool": + raise ValueError( + "`mask` must be binary, " + "i.e. VirtualWSIReader's mode has to be 'bool'" + ) + + mask_reader = mask + + elif isinstance(mask, np.ndarray): + if mask.dtype != bool: + raise ValueError( + "`mask` must be binary, i.e. `ndarray.dtype` has to be bool" + ) + + mask_reader = VirtualWSIReader(mask.astype(np.uint8)) + + else: # assume to be file path + if not os.path.isfile(mask): + raise ValueError("`mask` must be a valid file path.") + + mask = imread(mask) # assume to be gray + mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) + mask = np.array(mask > 0, dtype=np.uint8) + mask_reader = VirtualWSIReader(mask) + mask_reader.info = self.reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: + elif auto_get_mask and mode == "wsi" and mask is None: # if no mask provided and `wsi` mode, generate basic tissue # mask on the fly mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") @@ -312,16 +339,6 @@ def __init__( ) self.inputs = self.inputs[selected] - if len(self.inputs) == 0: - raise ValueError("No patch coordinates remain after filtering.") - - self.patch_input_shape = patch_input_shape - self.resolution = resolution - self.units = units - - # Perform check on the input - self._check_input_integrity(mode="wsi") - def __getitem__(self, idx): coords = self.inputs[idx] # Read image patch from the whole-slide image diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 94751538d..f5f51b870 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -1,6 +1,7 @@ import os import pathlib from abc import ABC, abstractmethod +from typing import Union import numpy as np import torch @@ -98,14 +99,17 @@ def _check_input_integrity(self, mode): raise ValueError("`inputs` should be a list of patch coordinates.") @staticmethod - def load_img(path): + def load_img(input_img: Union[str, pathlib.Path, np.ndarray]) -> np.ndarray: """Load an image from a provided path. Args: - path (str): Path to an image file. + input_img (str, pathlib.Path, np.ndarray): path to image or image data. + + Returns: + np.ndarray: image data. """ - path = pathlib.Path(path) + path = pathlib.Path(input_img) if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"): raise ValueError(f"Cannot load image data from `{path.suffix}` files.") diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 70b446c7f..7a1e3c90b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -4,7 +4,7 @@ import os import pathlib from collections import OrderedDict -from typing import Callable, Tuple, Union +from typing import Callable, List, Literal, Tuple, Union import numpy as np import torch @@ -16,7 +16,7 @@ from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig from tiatoolbox.utils import misc from tiatoolbox.utils.misc import save_as_json -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from tiatoolbox.wsicore.wsireader import WSIReader class IOPatchPredictorConfig(IOSegmentorConfig): @@ -158,8 +158,6 @@ class PatchPredictor: Whether to output logging information. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): - A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. @@ -229,7 +227,6 @@ def __init__( ): super().__init__() - self.imgs = None self.mode = None if model is None and pretrained_model is None: @@ -251,7 +248,7 @@ def __init__( @staticmethod def merge_predictions( - img: Union[str, pathlib.Path, np.ndarray], + input_img: Union[str, pathlib.Path, np.ndarray, WSIReader], output: dict, resolution: float = None, units: str = None, @@ -267,8 +264,8 @@ def merge_predictions( predicted by the model. Args: - img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): - A HWC image or a path to WSI. + input_img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): + Image to be processed. This can be a WSI, tile or patch. output (dict): Output generated by the model. resolution (float): @@ -308,16 +305,7 @@ def merge_predictions( ... [0, 0, 1, 1]]) """ - reader = WSIReader.open(img) - if isinstance(reader, VirtualWSIReader): - logger.warning( - "Image is not pyramidal hence read is forced to be " - "at `units='baseline'` and `resolution=1.0`.", - stacklevel=2, - ) - resolution = 1.0 - units = "baseline" - + reader = WSIReader.open(input_img) canvas_shape = reader.slide_dimensions(resolution=resolution, units=units) canvas_shape = canvas_shape[::-1] # XY to YX @@ -557,16 +545,16 @@ def _prepare_save_dir(save_dir, imgs): return save_dir - def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_gpu): + def _predict_patch( + self, input_imgs, labels, return_probabilities, return_labels, on_gpu + ): """Process patch mode. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + input_imgs (list): + List of inputs to process. Must be either a list of + images, a list of image file paths, WSIReader objects, + or a numpy array of an image list. labels: List of labels. If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is @@ -587,21 +575,21 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g # if a labels is provided, then return with the prediction return_labels = bool(labels) - if labels and len(labels) != len(imgs): + if labels and len(labels) != len(input_imgs): raise ValueError( - f"len(labels) != len(imgs) : " f"{len(labels)} != {len(imgs)}" + f"len(labels) != len(imgs) : " f"{len(labels)} != {len(input_imgs)}" ) # don't return coordinates if patches are already extracted return_coordinates = False - dataset = PatchDataset(imgs, labels) + dataset = PatchDataset(input_imgs, labels) return self._predict_engine( dataset, return_probabilities, return_labels, return_coordinates, on_gpu ) def _predict_tile_wsi( self, - imgs, + input_imgs, masks, labels, mode, @@ -616,12 +604,10 @@ def _predict_tile_wsi( """Predict on Tile and WSIs. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + input_imgs (list): + List of inputs to process. Must be either a list of + images, a list of image file paths, WSIReader objects, + or a numpy array of an image list. masks (list): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if @@ -678,22 +664,22 @@ def _predict_tile_wsi( # generate a list of output file paths if number of input images > 1 file_dict = OrderedDict() - if len(imgs) > 1: + if len(input_imgs) > 1: save_output = True - for idx, img_path in enumerate(imgs): - img_path = pathlib.Path(img_path) + for idx, input_img in enumerate(input_imgs): img_label = None if labels is None else labels[idx] img_mask = None if masks is None else masks[idx] dataset = WSIPatchDataset( - img_path, + input_img, mode=mode, - mask_path=img_mask, + mask=img_mask, patch_input_shape=ioconfig.patch_input_shape, stride_shape=ioconfig.stride_shape, resolution=ioconfig.input_resolutions[0]["resolution"], units=ioconfig.input_resolutions[0]["units"], + auto_get_mask=True, ) output_model = self._predict_engine( dataset, @@ -712,7 +698,7 @@ def _predict_tile_wsi( merged_prediction = None if merge_predictions: merged_prediction = self.merge_predictions( - img_path, + input_img, output_model, resolution=output_model["resolution"], units=output_model["units"], @@ -721,49 +707,81 @@ def _predict_tile_wsi( outputs.append(merged_prediction) if save_output: - # dynamic 0 padding - img_code = f"{idx:0{len(str(len(imgs)))}d}" - - save_info = {} - save_path = os.path.join(str(save_dir), img_code) - raw_save_path = f"{save_path}.raw.json" - save_info["raw"] = raw_save_path - save_as_json(output_model, raw_save_path) - if merge_predictions: - merged_file_path = f"{save_path}.merged.npy" - np.save(merged_file_path, merged_prediction) - save_info["merged"] = merged_file_path - file_dict[str(img_path)] = save_info + img_id, save_info = self._save_output( + output_model, + idx, + merged_prediction, + input_img, + input_imgs, + save_dir, + merge_predictions, + ) + file_dict[img_id] = save_info return file_dict if save_output else outputs + @staticmethod + def _save_output( + output_model, + idx, + merged_prediction, + input_img, + input_imgs, + save_dir, + merge_predictions, + ): + """Save prediction to json and/or numpy file.""" + # dynamic 0 padding + img_code = f"{idx:0{len(str(len(input_imgs)))}d}" + + save_info = {} + save_path = os.path.join(str(save_dir), img_code) + raw_save_path = f"{save_path}.raw.json" + save_info["raw"] = raw_save_path + save_as_json(output_model, raw_save_path) + if merge_predictions: + merged_file_path = f"{save_path}.merged.npy" + np.save(merged_file_path, merged_prediction) + save_info["merged"] = merged_file_path + + img_id = None + if isinstance(input_img, WSIReader): + img_id = str(input_img.input_path) + + elif isinstance(input_img, (str, pathlib.Path)): + img_id = str(input_img) + + if img_id is None: + img_id = idx + + return img_id, save_info + def predict( self, - imgs, - masks=None, - labels=None, - mode="patch", - return_probabilities=False, - return_labels=False, - on_gpu=True, + imgs: List[Union[str, pathlib.Path, np.ndarray, WSIReader]], + masks: List[Union[str, pathlib.Path, np.ndarray, WSIReader]] = None, + labels: List = None, + mode: Literal["patch", "tile", "wsi"] = "patch", + return_probabilities: bool = False, + return_labels: bool = False, + on_gpu: bool = True, ioconfig: IOPatchPredictorConfig = None, patch_input_shape: Tuple[int, int] = None, stride_shape: Tuple[int, int] = None, - resolution=None, - units=None, - merge_predictions=False, - save_dir=None, - save_output=False, + resolution: float = None, + units: str = None, + merge_predictions: bool = False, + save_dir: bool = None, + save_output: bool = False, + ignore_resolutions: bool = False, ): """Make a prediction for a list of input data. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + imgs (list): + List of inputs to process. Must be either a list of + images, a list of image file paths, WSIReader objects, + or a numpy array of an image list. masks (list): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if @@ -796,7 +814,7 @@ def predict( level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. resolution (float): - Resolution used for reading the image. Please see + Resolution used for reading the images. Please see :obj:`WSIReader` for details. units (str): Units of resolution used for reading the image. Choose @@ -812,7 +830,11 @@ def predict( where the running script is invoked. save_output (bool): Whether to save output for a single file. default=False - + ignore_resolutions (bool): + Whether to ignore the resolution of the input images. + PatchPredictor won't rescale the input images and will + and use them in the original resolution. + Works with `mode='patch'` only. Returns: (:class:`numpy.ndarray`, dict): Model predictions of the input dataset. If multiple @@ -846,10 +868,9 @@ def predict( raise ValueError( f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" ) - if mode == "patch": - return self._predict_patch( - imgs, labels, return_probabilities, return_labels, on_gpu - ) + + if mode == "patch" and masks is not None: + raise ValueError("masks are not supported for `patch` mode. ") if not isinstance(imgs, list): raise ValueError( @@ -861,14 +882,21 @@ def predict( f"len(masks) != len(imgs) : " f"{len(masks)} != {len(imgs)}" ) + if mode == "patch": + return self._predict_patch( + imgs, labels, return_probabilities, return_labels, on_gpu + ) + ioconfig = self._update_ioconfig( ioconfig, patch_input_shape, stride_shape, resolution, units ) - if mode == "tile": + + if mode == "tile" and ignore_resolutions: logger.warning( "WSIPatchDataset only reads image tile at " '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", + "to baseline value. " + "Set ignore_resolutions to False to change this behaviour.", stacklevel=2, ) ioconfig = ioconfig.to_baseline() diff --git a/tiatoolbox/tools/patchextraction.py b/tiatoolbox/tools/patchextraction.py index 24b9646c8..469a6595e 100644 --- a/tiatoolbox/tools/patchextraction.py +++ b/tiatoolbox/tools/patchextraction.py @@ -286,7 +286,9 @@ def filter_coordinates( tissue_mask = mask_reader.img # Scaling the coordinates_list to the `tissue_mask` array resolution - scale_factors = np.array(tissue_mask.shape[::-1]) / np.array(wsi_shape) + scale_factors = np.array(tissue_mask.shape[:2][::-1]) / np.array( + wsi_shape[:2] + ) # [:2] is to ignore the channel dimension if it exists scaled_coords = coordinates_list.copy().astype(np.float32) scaled_coords[:, [0, 2]] *= scale_factors[0] scaled_coords[:, [0, 2]] = np.clip( diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index 489df37ca..353043e94 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -253,6 +253,7 @@ def open( # noqa: A003 raise TypeError( "Invalid input: Must be a WSIRead, numpy array, string or pathlib.Path" ) + if isinstance(input_img, np.ndarray): return VirtualWSIReader(input_img, mpp=mpp, power=power) @@ -261,10 +262,28 @@ def open( # noqa: A003 # Input is a string or pathlib.Path, normalise to pathlib.Path input_path = pathlib.Path(input_img) + if not os.path.exists(input_path): + raise ValueError("`input_img` path must exist") + WSIReader.verify_supported_wsi(input_path) + return WSIReader.get_reader_by_filepath( + input_path, mpp=mpp, power=power, **kwargs + ) - # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) + @staticmethod + def get_reader_by_filepath( + input_path: pathlib.Path, + mpp: Optional[Tuple[Number, Number]] = None, + power: Optional[Number] = None, + **kwargs, + ) -> WSIReader: + """ + Returns an appropriate :class:`.WSIReader` object + based on the file extension. + """ + + # Handle special cases first (DICOM, Zarr/NGFF, OME-TIFF) if is_dicom(input_path): return DICOMWSIReader(input_path, mpp=mpp, power=power)