Skip to content

Commit 82d6a7c

Browse files
authored
Merge branch 'dev-define-engines-abc' into dev-convert-modelabc-to-torch-nn-model
2 parents a89e529 + d2381c0 commit 82d6a7c

File tree

2 files changed

+45
-33
lines changed

2 files changed

+45
-33
lines changed

tiatoolbox/models/dataset/dataset_abc.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -458,40 +458,18 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
458458
raise ValueError(msg)
459459

460460
self.preproc_func = preproc_func
461-
img_path = Path(img_path)
462-
if mode == "wsi":
463-
self.reader = WSIReader.open(img_path)
464-
else:
465-
logger.warning(
466-
"WSIPatchDataset only reads image tile at "
467-
'`units="baseline"` and `resolution=1.0`.',
468-
stacklevel=2,
469-
)
470-
img = imread(img_path)
471-
axes = "YXS"[: len(img.shape)]
472-
# initialise metadata for VirtualWSIReader.
473-
# here, we simulate a whole-slide image, but with a single level.
474-
# ! should we expose this so that use can provide their metadata ?
475-
metadata = WSIMeta(
476-
mpp=np.array([1.0, 1.0]),
477-
axes=axes,
478-
objective_power=10,
479-
slide_dimensions=np.array(img.shape[:2][::-1]),
480-
level_downsamples=[1.0],
481-
level_dimensions=[np.array(img.shape[:2][::-1])],
482-
)
483-
# infer value such that read if mask provided is through
484-
# 'mpp' or 'power' as varying 'baseline' is locked atm
461+
self.img_path = Path(img_path)
462+
self.mode = mode
463+
self.reader = None
464+
reader = self._get_reader(self.img_path)
465+
if mode != "wsi":
485466
units = "mpp"
486467
resolution = 1.0
487-
self.reader = VirtualWSIReader(
488-
img,
489-
info=metadata,
490-
)
491468

492469
# may decouple into misc ?
493470
# the scaling factor will scale base level to requested read resolution/units
494-
wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units)
471+
wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
472+
self.reader_info = reader.info
495473

496474
# use all patches, as long as it overlaps source image
497475
self.inputs = PatchExtractor.get_coordinates(
@@ -512,13 +490,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
512490
mask = np.array(mask > 0, dtype=np.uint8)
513491

514492
mask_reader = VirtualWSIReader(mask)
515-
mask_reader.info = self.reader.info
493+
mask_reader.info = reader.info
516494
elif auto_get_mask and mode == "wsi" and mask_path is None:
517495
# if no mask provided and `wsi` mode, generate basic tissue
518496
# mask on the fly
519-
mask_reader = self.reader.tissue_mask(resolution=1.25, units="power")
497+
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
520498
# ? will this mess up ?
521-
mask_reader.info = self.reader.info
499+
mask_reader.info = reader.info
522500

523501
if mask_reader is not None:
524502
selected = PatchExtractor.filter_coordinates(
@@ -540,10 +518,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
540518
# Perform check on the input
541519
self._check_input_integrity(mode="wsi")
542520

521+
def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
522+
"""Get a reader for the image."""
523+
if self.mode == "wsi":
524+
reader = WSIReader.open(img_path)
525+
else:
526+
logger.warning(
527+
"WSIPatchDataset only reads image tile at "
528+
'`units="baseline"` and `resolution=1.0`.',
529+
stacklevel=2,
530+
)
531+
img = imread(img_path)
532+
axes = "YXS"[: len(img.shape)]
533+
# initialise metadata for VirtualWSIReader.
534+
# here, we simulate a whole-slide image, but with a single level.
535+
# ! should we expose this so that use can provide their metadata ?
536+
metadata = WSIMeta(
537+
mpp=np.array([1.0, 1.0]),
538+
axes=axes,
539+
objective_power=10,
540+
slide_dimensions=np.array(img.shape[:2][::-1]),
541+
level_downsamples=[1.0],
542+
level_dimensions=[np.array(img.shape[:2][::-1])],
543+
)
544+
# infer value such that read if mask provided is through
545+
# 'mpp' or 'power' as varying 'baseline' is locked atm
546+
reader = VirtualWSIReader(
547+
img,
548+
info=metadata,
549+
)
550+
return reader
551+
543552
def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
544553
"""Get an item from the dataset."""
545554
coords = self.inputs[idx]
546555
# Read image patch from the whole-slide image
556+
if self.reader is None:
557+
# only set the reader on first call so that it is initially picklable
558+
self.reader = self._get_reader(self.img_path)
547559
patch = self.reader.read_bounds(
548560
coords,
549561
resolution=self.resolution,

tiatoolbox/models/engine/engine_abc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa
10231023
# equal to dataloader resolution.
10241024

10251025
if dataloader_units in ["mpp", "level", "power"]:
1026-
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
1026+
wsimeta_dict = dataloader.dataset.reader_info.as_dict()
10271027

10281028
if dataloader_units == "mpp":
10291029
slide_resolution = wsimeta_dict[dataloader_units]

0 commit comments

Comments
 (0)