@@ -458,40 +458,18 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
458458 raise ValueError (msg )
459459
460460 self .preproc_func = preproc_func
461- img_path = Path (img_path )
462- if mode == "wsi" :
463- self .reader = WSIReader .open (img_path )
464- else :
465- logger .warning (
466- "WSIPatchDataset only reads image tile at "
467- '`units="baseline"` and `resolution=1.0`.' ,
468- stacklevel = 2 ,
469- )
470- img = imread (img_path )
471- axes = "YXS" [: len (img .shape )]
472- # initialise metadata for VirtualWSIReader.
473- # here, we simulate a whole-slide image, but with a single level.
474- # ! should we expose this so that use can provide their metadata ?
475- metadata = WSIMeta (
476- mpp = np .array ([1.0 , 1.0 ]),
477- axes = axes ,
478- objective_power = 10 ,
479- slide_dimensions = np .array (img .shape [:2 ][::- 1 ]),
480- level_downsamples = [1.0 ],
481- level_dimensions = [np .array (img .shape [:2 ][::- 1 ])],
482- )
483- # infer value such that read if mask provided is through
484- # 'mpp' or 'power' as varying 'baseline' is locked atm
461+ self .img_path = Path (img_path )
462+ self .mode = mode
463+ self .reader = None
464+ reader = self ._get_reader (self .img_path )
465+ if mode != "wsi" :
485466 units = "mpp"
486467 resolution = 1.0
487- self .reader = VirtualWSIReader (
488- img ,
489- info = metadata ,
490- )
491468
492469 # may decouple into misc ?
493470 # the scaling factor will scale base level to requested read resolution/units
494- wsi_shape = self .reader .slide_dimensions (resolution = resolution , units = units )
471+ wsi_shape = reader .slide_dimensions (resolution = resolution , units = units )
472+ self .reader_info = reader .info
495473
496474 # use all patches, as long as it overlaps source image
497475 self .inputs = PatchExtractor .get_coordinates (
@@ -512,13 +490,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
512490 mask = np .array (mask > 0 , dtype = np .uint8 )
513491
514492 mask_reader = VirtualWSIReader (mask )
515- mask_reader .info = self . reader .info
493+ mask_reader .info = reader .info
516494 elif auto_get_mask and mode == "wsi" and mask_path is None :
517495 # if no mask provided and `wsi` mode, generate basic tissue
518496 # mask on the fly
519- mask_reader = self . reader .tissue_mask (resolution = 1.25 , units = "power" )
497+ mask_reader = reader .tissue_mask (resolution = 1.25 , units = "power" )
520498 # ? will this mess up ?
521- mask_reader .info = self . reader .info
499+ mask_reader .info = reader .info
522500
523501 if mask_reader is not None :
524502 selected = PatchExtractor .filter_coordinates (
@@ -540,10 +518,44 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915
540518 # Perform check on the input
541519 self ._check_input_integrity (mode = "wsi" )
542520
521+ def _get_reader (self : WSIPatchDataset , img_path : str | Path ) -> WSIReader :
522+ """Get a reader for the image."""
523+ if self .mode == "wsi" :
524+ reader = WSIReader .open (img_path )
525+ else :
526+ logger .warning (
527+ "WSIPatchDataset only reads image tile at "
528+ '`units="baseline"` and `resolution=1.0`.' ,
529+ stacklevel = 2 ,
530+ )
531+ img = imread (img_path )
532+ axes = "YXS" [: len (img .shape )]
533+ # initialise metadata for VirtualWSIReader.
534+ # here, we simulate a whole-slide image, but with a single level.
535+ # ! should we expose this so that use can provide their metadata ?
536+ metadata = WSIMeta (
537+ mpp = np .array ([1.0 , 1.0 ]),
538+ axes = axes ,
539+ objective_power = 10 ,
540+ slide_dimensions = np .array (img .shape [:2 ][::- 1 ]),
541+ level_downsamples = [1.0 ],
542+ level_dimensions = [np .array (img .shape [:2 ][::- 1 ])],
543+ )
544+ # infer value such that read if mask provided is through
545+ # 'mpp' or 'power' as varying 'baseline' is locked atm
546+ reader = VirtualWSIReader (
547+ img ,
548+ info = metadata ,
549+ )
550+ return reader
551+
543552 def __getitem__ (self : WSIPatchDataset , idx : int ) -> dict :
544553 """Get an item from the dataset."""
545554 coords = self .inputs [idx ]
546555 # Read image patch from the whole-slide image
556+ if self .reader is None :
557+ # only set the reader on first call so that it is initially picklable
558+ self .reader = self ._get_reader (self .img_path )
547559 patch = self .reader .read_bounds (
548560 coords ,
549561 resolution = self .resolution ,
0 commit comments