Skip to content

Commit d5c1995

Browse files
authored
⚡ Make WSIPatchDataset Pickleable to Support Windows Multithreading (#947)
This PR makes the `WSIPatchDataset` class picklable by delaying the creation of the reader object until the first call to `__getitem__`. This enables the use of multiple loader workers on Windows without errors and provides significant performance improvements.

- Delays reader object instantiation to the first `__getitem__` call instead of during initialization
- Extracts reader creation logic into a separate `_get_reader` method
- Stores image path and mode as instance variables for lazy initialization

Speedup for the WSI prediction cell of the patch_prediction example notebook: 2 min 48 sec with 0 loader workers -> 1 min 13 sec with 4 workers.

Note: this PR has no effect on Linux, where multiple loader workers already work because the worker processes are started there without requiring the dataset to be picklable.
1 parent c1eb36c commit d5c1995

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

tiatoolbox/models/dataset/classification.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):
163163
164164
"""
165165

166-
def __init__( # skipcq: PY-R1000 # noqa: PLR0915
166+
def __init__( # skipcq: PY-R1000
167167
self: WSIPatchDataset,
168168
img_path: str | Path,
169169
mode: str = "wsi",
@@ -262,40 +262,17 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
262262
raise ValueError(msg)
263263

264264
self.preproc_func = preproc_func
265-
img_path = Path(img_path)
266-
if mode == "wsi":
267-
self.reader = WSIReader.open(img_path)
268-
else:
269-
logger.warning(
270-
"WSIPatchDataset only reads image tile at "
271-
'`units="baseline"` and `resolution=1.0`.',
272-
stacklevel=2,
273-
)
274-
img = imread(img_path)
275-
axes = "YXS"[: len(img.shape)]
276-
# initialise metadata for VirtualWSIReader.
277-
# here, we simulate a whole-slide image, but with a single level.
278-
# ! should we expose this so that users can provide their metadata ?
279-
metadata = WSIMeta(
280-
mpp=np.array([1.0, 1.0]),
281-
axes=axes,
282-
objective_power=10,
283-
slide_dimensions=np.array(img.shape[:2][::-1]),
284-
level_downsamples=[1.0],
285-
level_dimensions=[np.array(img.shape[:2][::-1])],
286-
)
287-
# infer value such that read if mask provided is through
288-
# 'mpp' or 'power' as varying 'baseline' is locked atm
265+
self.img_path = Path(img_path)
266+
self.mode = mode
267+
self.reader = None
268+
reader = self._get_reader(self.img_path)
269+
if mode != "wsi":
289270
units = "mpp"
290271
resolution = 1.0
291-
self.reader = VirtualWSIReader(
292-
img,
293-
info=metadata,
294-
)
295272

296273
# may decouple into misc ?
297274
# the scaling factor will scale base level to requested read resolution/units
298-
wsi_shape = self.reader.slide_dimensions(resolution=resolution, units=units)
275+
wsi_shape = reader.slide_dimensions(resolution=resolution, units=units)
299276

300277
# use all patches, as long as it overlaps source image
301278
self.inputs = PatchExtractor.get_coordinates(
@@ -316,13 +293,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
316293
mask = np.array(mask > 0, dtype=np.uint8)
317294

318295
mask_reader = VirtualWSIReader(mask)
319-
mask_reader.info = self.reader.info
296+
mask_reader.info = reader.info
320297
elif auto_get_mask and mode == "wsi" and mask_path is None:
321298
# if no mask provided and `wsi` mode, generate basic tissue
322299
# mask on the fly
323-
mask_reader = self.reader.tissue_mask(resolution=1.25, units="power")
300+
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
324301
# ? will this mess up ?
325-
mask_reader.info = self.reader.info
302+
mask_reader.info = reader.info
326303

327304
if mask_reader is not None:
328305
selected = PatchExtractor.filter_coordinates(
@@ -344,10 +321,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
344321
# Perform check on the input
345322
self._check_input_integrity(mode="wsi")
346323

324+
def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader:
325+
"""Get a reader for the image."""
326+
if self.mode == "wsi":
327+
reader = WSIReader.open(img_path)
328+
else:
329+
logger.warning(
330+
"WSIPatchDataset only reads image tile at "
331+
'`units="baseline"` and `resolution=1.0`.',
332+
stacklevel=2,
333+
)
334+
img = imread(img_path)
335+
axes = "YXS"[: len(img.shape)]
336+
# initialise metadata for VirtualWSIReader.
337+
# here, we simulate a whole-slide image, but with a single level.
338+
# ! should we expose this so that users can provide their metadata ?
339+
metadata = WSIMeta(
340+
mpp=np.array([1.0, 1.0]),
341+
axes=axes,
342+
objective_power=10,
343+
slide_dimensions=np.array(img.shape[:2][::-1]),
344+
level_downsamples=[1.0],
345+
level_dimensions=[np.array(img.shape[:2][::-1])],
346+
)
347+
# infer value such that read if mask provided is through
348+
# 'mpp' or 'power' as varying 'baseline' is locked atm
349+
reader = VirtualWSIReader(
350+
img,
351+
info=metadata,
352+
)
353+
return reader
354+
347355
def __getitem__(self: WSIPatchDataset, idx: int) -> dict:
348356
"""Get an item from the dataset."""
349357
coords = self.inputs[idx]
350358
# Read image patch from the whole-slide image
359+
if self.reader is None:
360+
# only set the reader on first call so that it is initially picklable
361+
self.reader = self._get_reader(self.img_path)
351362
patch = self.reader.read_bounds(
352363
coords,
353364
resolution=self.resolution,

0 commit comments

Comments
 (0)