Skip to content

Commit 6721567

Browse files
authored
Merge branch 'develop' into enhance-hf-weights
2 parents 41ba556 + d5c1995 commit 6721567

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

tiatoolbox/models/dataset/classification.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):
163163
164164
"""
165165

166-
def __init__( # skipcq: PY-R1000 # noqa: PLR0915
166+
def __init__( # skipcq: PY-R1000
167167
self: WSIPatchDataset,
168168
img_path: str | Path,
169169
mode: str = "wsi",
@@ -262,40 +262,17 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
262262
raise ValueError(msg)
263263

264264
self.preproc_func = preproc_func
265-
img_path = Path(img_path)
266-
if mode == "wsi":
267-
self.reader = WSIReader.open(img_path)
268-
else:
269-
logger.warning(
270-
"WSIPatchDataset only reads image tile at "
271-
'`units="baseline"` and `resolution=1.0`.',
272-
stacklevel=2,
273-
)
274-
img = imread(img_path)
275-
axes = "YXS"[: len(img.shape)]
276-
# initialise metadata for VirtualWSIReader.
277-
# here, we simulate a whole-slide image, but with a single level.
278-
# ! should we expose this so that use can provide their metadata ?
279-
metadata = WSIMeta(
280-
mpp=np.array([1.0, 1.0]),
281-
axes=axes,
282-
objective_power=10,
283-
slide_dimensions=np.array(img.shape[:2][::-1]),
284-
level_downsamples=[1.0],
285-
level_dimensions=[np.array(img.shape[:2][::-1])],
286-
)
287-
# infer value such that read if mask provided is through
288-
# 'mpp' or 'power' as varying 'baseline' is locked atm
265+
self.img_path = Path(img_path)
266+
self.mode = mode
267+
self.reader = None
268+
reader = self._get_reader(self.img_path)
269+
if mode != "wsi":
289270
units = "mpp"
290271
resolution = 1.0
291-
self.reader = VirtualWSIReader(
292-
img,
293-
info=metadata,
294-
)
295272

296273
# may decouple into misc ?
297274
# the scaling factor will scale base level to requested read resolution/units
298-
wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units)
275+
wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
299276

300277
# use all patches, as long as it overlaps source image
301278
self.inputs = PatchExtractor.get_coordinates(
@@ -316,13 +293,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
316293
mask = np.array(mask > 0, dtype=np.uint8)
317294

318295
mask_reader = VirtualWSIReader(mask)
319-
mask_reader.info = self.reader.info
296+
mask_reader.info = reader.info
320297
elif auto_get_mask and mode == "wsi" and mask_path is None:
321298
# if no mask provided and `wsi` mode, generate basic tissue
322299
# mask on the fly
323-
mask_reader = self.reader.tissue_mask(resolution=1.25, units="power")
300+
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
324301
# ? will this mess up ?
325-
mask_reader.info = self.reader.info
302+
mask_reader.info = reader.info
326303

327304
if mask_reader is not None:
328305
selected = PatchExtractor.filter_coordinates(
@@ -344,10 +321,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
344321
# Perform check on the input
345322
self._check_input_integrity(mode="wsi")
346323

324+
def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
325+
"""Get a reader for the image."""
326+
if self.mode == "wsi":
327+
reader = WSIReader.open(img_path)
328+
else:
329+
logger.warning(
330+
"WSIPatchDataset only reads image tile at "
331+
'`units="baseline"` and `resolution=1.0`.',
332+
stacklevel=2,
333+
)
334+
img = imread(img_path)
335+
axes = "YXS"[: len(img.shape)]
336+
# initialise metadata for VirtualWSIReader.
337+
# here, we simulate a whole-slide image, but with a single level.
338+
# ! should we expose this so that use can provide their metadata ?
339+
metadata = WSIMeta(
340+
mpp=np.array([1.0, 1.0]),
341+
axes=axes,
342+
objective_power=10,
343+
slide_dimensions=np.array(img.shape[:2][::-1]),
344+
level_downsamples=[1.0],
345+
level_dimensions=[np.array(img.shape[:2][::-1])],
346+
)
347+
# infer value such that read if mask provided is through
348+
# 'mpp' or 'power' as varying 'baseline' is locked atm
349+
reader = VirtualWSIReader(
350+
img,
351+
info=metadata,
352+
)
353+
return reader
354+
347355
def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
348356
"""Get an item from the dataset."""
349357
coords = self.inputs[idx]
350358
# Read image patch from the whole-slide image
359+
if self.reader is None:
360+
# only set the reader on first call so that it is initially picklable
361+
self.reader = self._get_reader(self.img_path)
351362
patch = self.reader.read_bounds(
352363
coords,
353364
resolution=self.resolution,

0 commit comments

Comments
 (0)