Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions tiatoolbox/models/dataset/dataset_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,40 +458,18 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
raise ValueError(msg)

self.preproc_func = preproc_func
img_path = Path(img_path)
if mode == "wsi":
self.reader = WSIReader.open(img_path)
else:
logger.warning(
"WSIPatchDataset only reads image tile at "
'`units="baseline"` and `resolution=1.0`.',
stacklevel=2,
)
img = imread(img_path)
axes = "YXS"[: len(img.shape)]
# initialise metadata for VirtualWSIReader.
# here, we simulate a whole-slide image, but with a single level.
# ! should we expose this so that use can provide their metadata ?
metadata = WSIMeta(
mpp=np.array([1.0, 1.0]),
axes=axes,
objective_power=10,
slide_dimensions=np.array(img.shape[:2][::-1]),
level_downsamples=[1.0],
level_dimensions=[np.array(img.shape[:2][::-1])],
)
# infer value such that read if mask provided is through
# 'mpp' or 'power' as varying 'baseline' is locked atm
self.img_path = Path(img_path)
self.mode = mode
self.reader = None
reader = self._get_reader(self.img_path)
if mode != "wsi":
units = "mpp"
resolution = 1.0
self.reader = VirtualWSIReader(
img,
info=metadata,
)

# may decouple into misc ?
# the scaling factor will scale base level to requested read resolution/units
wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units)
wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
self.reader_info = reader.info

# use all patches, as long as it overlaps source image
self.inputs = PatchExtractor.get_coordinates(
Expand All @@ -512,13 +490,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
mask = np.array(mask > 0, dtype=np.uint8)

mask_reader = VirtualWSIReader(mask)
mask_reader.info = self.reader.info
mask_reader.info = reader.info
elif auto_get_mask and mode == "wsi" and mask_path is None:
# if no mask provided and `wsi` mode, generate basic tissue
# mask on the fly
mask_reader = self.reader.tissue_mask(resolution=1.25, units="power")
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
# ? will this mess up ?
mask_reader.info = self.reader.info
mask_reader.info = reader.info

if mask_reader is not None:
selected = PatchExtractor.filter_coordinates(
Expand All @@ -540,10 +518,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
# Perform check on the input
self._check_input_integrity(mode="wsi")

def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
"""Get a reader for the image."""
if self.mode == "wsi":
reader = WSIReader.open(img_path)
else:
logger.warning(
"WSIPatchDataset only reads image tile at "
'`units="baseline"` and `resolution=1.0`.',
stacklevel=2,
)
img = imread(img_path)
axes = "YXS"[: len(img.shape)]
# initialise metadata for VirtualWSIReader.
# here, we simulate a whole-slide image, but with a single level.
# ! should we expose this so that use can provide their metadata ?
metadata = WSIMeta(
mpp=np.array([1.0, 1.0]),
axes=axes,
objective_power=10,
slide_dimensions=np.array(img.shape[:2][::-1]),
level_downsamples=[1.0],
level_dimensions=[np.array(img.shape[:2][::-1])],
)
# infer value such that read if mask provided is through
# 'mpp' or 'power' as varying 'baseline' is locked atm
reader = VirtualWSIReader(
img,
info=metadata,
)
return reader

def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
"""Get an item from the dataset."""
coords = self.inputs[idx]
# Read image patch from the whole-slide image
if self.reader is None:
# only set the reader on first call so that it is initially picklable
self.reader = self._get_reader(self.img_path)
patch = self.reader.read_bounds(
coords,
resolution=self.resolution,
Expand Down
2 changes: 1 addition & 1 deletion tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa
# equal to dataloader resolution.

if dataloader_units in ["mpp", "level", "power"]:
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
wsimeta_dict = dataloader.dataset.reader_info.as_dict()

if dataloader_units == "mpp":
slide_resolution = wsimeta_dict[dataloader_units]
Expand Down