diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index 30c76688e..35851472e 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -41,7 +41,8 @@ data: batch_size: 32 num_workers: 16 yx_patch_size: [256, 256] - normalizations: + pyramid_resolution: "0" + normalizations: - class_path: viscy.transforms.NormalizeSampled init_args: keys: [source] @@ -92,3 +93,4 @@ data: sigma_y: [0.25, 1.5] sigma_x: [0.25, 1.5] caching: false + diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml index 9cc8f5856..2d04d645f 100644 --- a/examples/configs/predict_example.yml +++ b/examples/configs/predict_example.yml @@ -66,5 +66,6 @@ predict: - 256 - 256 caching: false + pyramid_resolution: "0" return_predictions: false ckpt_path: null diff --git a/examples/configs/test_example.yml b/examples/configs/test_example.yml index 6dac8384b..b1d417271 100644 --- a/examples/configs/test_example.yml +++ b/examples/configs/test_example.yml @@ -66,5 +66,7 @@ data: - 256 caching: false ground_truth_masks: null + pyramid_resolution: "0" ckpt_path: null verbose: true +a \ No newline at end of file diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 52d0e5103..2df62a228 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -104,8 +104,12 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D + :param str array_key: + Name of the image arrays (multiscales level), by default "0" :param DictTransform | None transform: a callable that transforms data, defaults to None + :param bool load_normalization_metadata: + whether to load normalization metadata, defaults to True """ def __init__( @@ -113,7 +117,9 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, + array_key: str = "0", transform: DictTransform | None = None, + load_normalization_metadata: bool = True, ) -> None: super().__init__() self.positions = positions @@ -128,7 +134,9 @@ def __init__( ) self.z_window_size = z_window_size self.transform = transform + self.array_key = array_key self._get_windows() + self.load_normalization_metadata = load_normalization_metadata def _get_windows(self) -> None: """Count the sliding windows along T and Z, @@ -138,7 +146,7 @@ def _get_windows(self) -> None: self.window_arrays = [] self.window_norm_meta: list[NormMeta | None] = [] for fov in self.positions: - img_arr: ImageArray = fov["0"] + img_arr: ImageArray = fov[str(self.array_key)] ts = img_arr.frames zs = img_arr.slices - self.z_window_size + 1 if zs < 1: @@ -225,10 +233,11 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), - "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") + if self.load_normalization_metadata: + sample["norm_meta"] = norm_meta return sample @@ -326,6 +335,8 @@ class HCSDataModule(LightningDataModule): prefetch_factor : int or None, optional Number of samples loaded in advance by each worker during fitting, defaults to None (2 per PyTorch default). + array_key : str, optional + Name of the image arrays (multiscales level), by default "0" """ def __init__( @@ -345,6 +356,7 @@ def __init__( ground_truth_masks: Path | None = None, persistent_workers=False, prefetch_factor=None, + array_key: str = "0", ): super().__init__() self.data_path = Path(data_path) @@ -363,6 +375,7 @@ def __init__( self.prepare_data_per_node = True self.persistent_workers = persistent_workers self.prefetch_factor = prefetch_factor + self.array_key = array_key @property def cache_path(self): @@ -419,6 +432,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: return { "channels": {"source": self.source_channel}, "z_window_size": self.z_window_size, + "array_key": self.array_key, } def setup(self, stage: Literal["fit", "validate", "test", "predict"]):