diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index edf4e28aa..b73d2c9f3 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC): """ - def __init__( # skipcq: PY-R1000 # noqa: PLR0915 + def __init__( # skipcq: PY-R1000 self: WSIPatchDataset, img_path: str | Path, mode: str = "wsi", @@ -262,40 +262,17 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 raise ValueError(msg) self.preproc_func = preproc_func - img_path = Path(img_path) - if mode == "wsi": - self.reader = WSIReader.open(img_path) - else: - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # infer value such that read if mask provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm + self.img_path = Path(img_path) + self.mode = mode + self.reader = None + reader = self._get_reader(self.img_path) + if mode != "wsi": units = "mpp" resolution = 1.0 - self.reader = VirtualWSIReader( - img, - info=metadata, - ) # may decouple into misc ? # the scaling factor will scale base level to requested read resolution/units - wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units) + wsi_shape = reader.slide_dimensions(resolution=resolution, units=units) # use all patches, as long as it overlaps source image self.inputs = PatchExtractor.get_coordinates( @@ -316,13 +293,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 mask = np.array(mask > 0, dtype=np.uint8) mask_reader = VirtualWSIReader(mask) - mask_reader.info = self.reader.info + mask_reader.info = reader.info elif auto_get_mask and mode == "wsi" and mask_path is None: # if no mask provided and `wsi` mode, generate basic tissue # mask on the fly - mask_reader = self.reader.tissue_mask(resolution=1.25, units="power") + mask_reader = reader.tissue_mask(resolution=1.25, units="power") # ? will this mess up ? - mask_reader.info = self.reader.info + mask_reader.info = reader.info if mask_reader is not None: selected = PatchExtractor.filter_coordinates( @@ -344,10 +321,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 # Perform check on the input self._check_input_integrity(mode="wsi") + def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader: + """Get a reader for the image.""" + if self.mode == "wsi": + reader = WSIReader.open(img_path) + else: + logger.warning( + "WSIPatchDataset only reads image tile at " + '`units="baseline"` and `resolution=1.0`.', + stacklevel=2, + ) + img = imread(img_path) + axes = "YXS"[: len(img.shape)] + # initialise metadata for VirtualWSIReader. + # here, we simulate a whole-slide image, but with a single level. + # ! should we expose this so that use can provide their metadata ? + metadata = WSIMeta( + mpp=np.array([1.0, 1.0]), + axes=axes, + objective_power=10, + slide_dimensions=np.array(img.shape[:2][::-1]), + level_downsamples=[1.0], + level_dimensions=[np.array(img.shape[:2][::-1])], + ) + # infer value such that read if mask provided is through + # 'mpp' or 'power' as varying 'baseline' is locked atm + reader = VirtualWSIReader( + img, + info=metadata, + ) + return reader + def __getitem__(self: WSIPatchDataset, idx: int) -> dict: """Get an item from the dataset.""" coords = self.inputs[idx] # Read image patch from the whole-slide image + if self.reader is None: + # only set the reader on first call so that it is initially picklable + self.reader = self._get_reader(self.img_path) patch = self.reader.read_bounds( coords, resolution=self.resolution,