@@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):
163163
164164 """
165165
166- def __init__ ( # skipcq: PY-R1000 # noqa: PLR0915
166+ def __init__ ( # skipcq: PY-R1000
167167 self : WSIPatchDataset ,
168168 img_path : str | Path ,
169169 mode : str = "wsi" ,
@@ -262,40 +262,17 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
262262 raise ValueError (msg )
263263
264264 self .preproc_func = preproc_func
265- img_path = Path (img_path )
266- if mode == "wsi" :
267- self .reader = WSIReader .open (img_path )
268- else :
269- logger .warning (
270- "WSIPatchDataset only reads image tile at "
271- '`units="baseline"` and `resolution=1.0`.' ,
272- stacklevel = 2 ,
273- )
274- img = imread (img_path )
275- axes = "YXS" [: len (img .shape )]
276- # initialise metadata for VirtualWSIReader.
277- # here, we simulate a whole-slide image, but with a single level.
278- # ! should we expose this so that use can provide their metadata ?
279- metadata = WSIMeta (
280- mpp = np .array ([1.0 , 1.0 ]),
281- axes = axes ,
282- objective_power = 10 ,
283- slide_dimensions = np .array (img .shape [:2 ][::- 1 ]),
284- level_downsamples = [1.0 ],
285- level_dimensions = [np .array (img .shape [:2 ][::- 1 ])],
286- )
287- # infer value such that read if mask provided is through
288- # 'mpp' or 'power' as varying 'baseline' is locked atm
265+ self .img_path = Path (img_path )
266+ self .mode = mode
267+ self .reader = None
268+ reader = self ._get_reader (self .img_path )
269+ if mode != "wsi" :
289270 units = "mpp"
290271 resolution = 1.0
291- self .reader = VirtualWSIReader (
292- img ,
293- info = metadata ,
294- )
295272
296273 # may decouple into misc ?
297274 # the scaling factor will scale base level to requested read resolution/units
298- wsi_shape = self . reader .slide_dimensions (resolution = resolution , units = units )
275+ wsi_shape = reader .slide_dimensions (resolution = resolution , units = units )
299276
300277 # use all patches, as long as it overlaps source image
301278 self .inputs = PatchExtractor .get_coordinates (
@@ -316,13 +293,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
316293 mask = np .array (mask > 0 , dtype = np .uint8 )
317294
318295 mask_reader = VirtualWSIReader (mask )
319- mask_reader .info = self . reader .info
296+ mask_reader .info = reader .info
320297 elif auto_get_mask and mode == "wsi" and mask_path is None :
321298 # if no mask provided and `wsi` mode, generate basic tissue
322299 # mask on the fly
323- mask_reader = self . reader .tissue_mask (resolution = 1.25 , units = "power" )
300+ mask_reader = reader .tissue_mask (resolution = 1.25 , units = "power" )
324301 # ? will this mess up ?
325- mask_reader .info = self . reader .info
302+ mask_reader .info = reader .info
326303
327304 if mask_reader is not None :
328305 selected = PatchExtractor .filter_coordinates (
@@ -344,10 +321,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
344321 # Perform check on the input
345322 self ._check_input_integrity (mode = "wsi" )
346323
324+ def _get_reader (self : WSIPatchDataset , img_path : str | Path ) -> WSIReader :
325+ """Get a reader for the image."""
326+ if self .mode == "wsi" :
327+ reader = WSIReader .open (img_path )
328+ else :
329+ logger .warning (
330+ "WSIPatchDataset only reads image tile at "
331+ '`units="baseline"` and `resolution=1.0`.' ,
332+ stacklevel = 2 ,
333+ )
334+ img = imread (img_path )
335+ axes = "YXS" [: len (img .shape )]
336+ # initialise metadata for VirtualWSIReader.
337+ # here, we simulate a whole-slide image, but with a single level.
338+ # ! should we expose this so that use can provide their metadata ?
339+ metadata = WSIMeta (
340+ mpp = np .array ([1.0 , 1.0 ]),
341+ axes = axes ,
342+ objective_power = 10 ,
343+ slide_dimensions = np .array (img .shape [:2 ][::- 1 ]),
344+ level_downsamples = [1.0 ],
345+ level_dimensions = [np .array (img .shape [:2 ][::- 1 ])],
346+ )
347+ # infer value such that read if mask provided is through
348+ # 'mpp' or 'power' as varying 'baseline' is locked atm
349+ reader = VirtualWSIReader (
350+ img ,
351+ info = metadata ,
352+ )
353+ return reader
354+
347355 def __getitem__ (self : WSIPatchDataset , idx : int ) -> dict :
348356 """Get an item from the dataset."""
349357 coords = self .inputs [idx ]
350358 # Read image patch from the whole-slide image
359+ if self .reader is None :
360+ # only set the reader on first call so that it is initially picklable
361+ self .reader = self ._get_reader (self .img_path )
351362 patch = self .reader .read_bounds (
352363 coords ,
353364 resolution = self .resolution ,
0 commit comments