Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 44 additions & 33 deletions tiatoolbox/models/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):

"""

def __init__( # skipcq: PY-R1000 # noqa: PLR0915
def __init__( # skipcq: PY-R1000
self: WSIPatchDataset,
img_path: str | Path,
mode: str = "wsi",
Expand Down Expand Up @@ -262,40 +262,17 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
raise ValueError(msg)

self.preproc_func = preproc_func
img_path = Path(img_path)
if mode == "wsi":
self.reader = WSIReader.open(img_path)
else:
logger.warning(
"WSIPatchDataset only reads image tile at "
'`units="baseline"` and `resolution=1.0`.',
stacklevel=2,
)
img = imread(img_path)
axes = "YXS"[: len(img.shape)]
# initialise metadata for VirtualWSIReader.
# here, we simulate a whole-slide image, but with a single level.
# ! should we expose this so that users can provide their own metadata ?
metadata = WSIMeta(
mpp=np.array([1.0, 1.0]),
axes=axes,
objective_power=10,
slide_dimensions=np.array(img.shape[:2][::-1]),
level_downsamples=[1.0],
level_dimensions=[np.array(img.shape[:2][::-1])],
)
# infer values so that, if a mask is provided, it is read through
# 'mpp' or 'power', since varying 'baseline' is locked at the moment
self.img_path = Path(img_path)
self.mode = mode
self.reader = None
reader = self._get_reader(self.img_path)
if mode != "wsi":
units = "mpp"
resolution = 1.0
self.reader = VirtualWSIReader(
img,
info=metadata,
)

# may decouple into misc ?
# the scaling factor will scale base level to requested read resolution/units
wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units)
wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
Copy link

Copilot AI Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a temporary reader object during initialization defeats the purpose of lazy initialization. This reader is only used to get slide dimensions but will be discarded, causing unnecessary overhead. Consider caching the slide dimensions or refactoring to avoid creating the reader twice.

Copilot uses AI. Check for mistakes.

# use all patches, as long as it overlaps source image
self.inputs = PatchExtractor.get_coordinates(
Expand All @@ -316,13 +293,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
mask = np.array(mask > 0, dtype=np.uint8)

mask_reader = VirtualWSIReader(mask)
mask_reader.info = self.reader.info
mask_reader.info = reader.info
elif auto_get_mask and mode == "wsi" and mask_path is None:
# if no mask provided and `wsi` mode, generate basic tissue
# mask on the fly
mask_reader = self.reader.tissue_mask(resolution=1.25, units="power")
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
# ? will this mess up ?
mask_reader.info = self.reader.info
mask_reader.info = reader.info

if mask_reader is not None:
selected = PatchExtractor.filter_coordinates(
Expand All @@ -344,10 +321,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
# Perform check on the input
self._check_input_integrity(mode="wsi")

def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
    """Create the reader used to pull patches out of `img_path`.

    When `self.mode` is ``"wsi"`` the path is handed to `WSIReader.open`.
    For any other mode the image is loaded with `imread` and wrapped in a
    `VirtualWSIReader` that carries placeholder, single-level metadata.

    Args:
        img_path (str | Path):
            Location of the whole-slide image or image tile to read.

    Returns:
        WSIReader:
            The reader built for `img_path`.

    """
    # Guard clause: real whole-slide images go straight to the reader factory.
    if self.mode == "wsi":
        return WSIReader.open(img_path)

    logger.warning(
        "WSIPatchDataset only reads image tile at "
        '`units="baseline"` and `resolution=1.0`.',
        stacklevel=2,
    )
    tile = imread(img_path)
    n_dims = len(tile.shape)
    # First two array axes reversed, i.e. (x, y) under the "YXS" layout.
    dims_xy = tile.shape[:2][::-1]

    # Simulate a whole-slide image that has exactly one resolution level.
    # ! should we expose this so that users can provide their own metadata ?
    virtual_info = WSIMeta(
        mpp=np.array([1.0, 1.0]),
        axes="YXS"[:n_dims],
        objective_power=10,
        slide_dimensions=np.array(dims_xy),
        level_downsamples=[1.0],
        level_dimensions=[np.array(dims_xy)],
    )
    return VirtualWSIReader(tile, info=virtual_info)

def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
"""Get an item from the dataset."""
coords = self.inputs[idx]
# Read image patch from the whole-slide image
if self.reader is None:
Copy link

Copilot AI Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lazy initialization of self.reader is not thread-safe. Multiple threads could simultaneously check if self.reader is None and create multiple reader instances, potentially causing race conditions in a multi-threaded environment.

Copilot uses AI. Check for mistakes.
# only set the reader on first call so that it is initially picklable
self.reader = self._get_reader(self.img_path)
patch = self.reader.read_bounds(
coords,
resolution=self.resolution,
Expand Down