|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +import tempfile |
| 5 | +from logging import getLogger |
| 6 | +from pathlib import Path |
| 7 | +from typing import TYPE_CHECKING, Literal |
| 8 | + |
| 9 | +import numpy as np |
| 10 | +import torch |
| 11 | +from iohub.ngff import Plate, Position, open_ome_zarr |
| 12 | +from monai.data.meta_obj import set_track_meta |
| 13 | +from monai.transforms.compose import Compose |
| 14 | +from tensordict.memmap import MemoryMappedTensor |
| 15 | +from torch import Tensor |
| 16 | +from torch.multiprocessing import Manager |
| 17 | +from torch.utils.data import Dataset |
| 18 | + |
| 19 | +from viscy.data.gpu_aug import GPUTransformDataModule, SelectWell |
| 20 | +from viscy.data.hcs import _ensure_channel_list, _read_norm_meta |
| 21 | +from viscy.data.typing import DictTransform, NormMeta |
| 22 | + |
| 23 | +if TYPE_CHECKING: |
| 24 | + from multiprocessing.managers import DictProxy |
| 25 | + |
| 26 | +_logger = getLogger("lightning.pytorch") |
| 27 | + |
| 28 | +_CacheMetadata = tuple[Position, int, NormMeta | None] |
| 29 | + |
| 30 | + |
| 31 | +class MmappedDataset(Dataset): |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + positions: list[Position], |
| 35 | + channel_names: list[str], |
| 36 | + cache_map: DictProxy, |
| 37 | + buffer: MemoryMappedTensor, |
| 38 | + preprocess_transforms: Compose | None = None, |
| 39 | + cpu_transform: Compose | None = None, |
| 40 | + array_key: str = "0", |
| 41 | + load_normalization_metadata: bool = True, |
| 42 | + ): |
| 43 | + key = 0 |
| 44 | + self._metadata_map: dict[int, _CacheMetadata] = {} |
| 45 | + for position in positions: |
| 46 | + img = position[array_key] |
| 47 | + norm_meta = _read_norm_meta(position) |
| 48 | + for time_idx in range(img.frames): |
| 49 | + cache_map[key] = None |
| 50 | + self._metadata_map[key] = (position, time_idx, norm_meta) |
| 51 | + key += 1 |
| 52 | + self.channels = {ch: position.get_channel_index(ch) for ch in channel_names} |
| 53 | + self.array_key = array_key |
| 54 | + self._buffer = buffer |
| 55 | + self._cache_map = cache_map |
| 56 | + self.preprocess_transforms = preprocess_transforms |
| 57 | + self.cpu_transform = cpu_transform |
| 58 | + self.load_normalization_metadata = load_normalization_metadata |
| 59 | + |
| 60 | + def __len__(self) -> int: |
| 61 | + return len(self._metadata_map) |
| 62 | + |
| 63 | + def _split_channels(self, volume: Tensor) -> dict[str, Tensor]: |
| 64 | + return {name: img[None] for name, img in zip(self.channels.keys(), volume)} |
| 65 | + |
| 66 | + def _preprocess_volume(self, volume: Tensor, norm_meta) -> Tensor: |
| 67 | + if self.preprocess_transforms: |
| 68 | + orig_shape = volume.shape |
| 69 | + sample = self._split_channels(volume) |
| 70 | + if self.load_normalization_metadata: |
| 71 | + sample["norm_meta"] = norm_meta |
| 72 | + sample = self.preprocess_transforms(sample) |
| 73 | + volume = torch.cat([sample[name] for name in self.channels.keys()], dim=0) |
| 74 | + assert volume.shape == orig_shape, (volume.shape, orig_shape, sample.keys()) |
| 75 | + return volume |
| 76 | + |
| 77 | + def __getitem__(self, idx: int) -> dict[str, Tensor]: |
| 78 | + position, time_idx, norm_meta = self._metadata_map[idx] |
| 79 | + if not self._cache_map[idx]: |
| 80 | + _logger.debug(f"Loading volume for index {idx}") |
| 81 | + volume = torch.from_numpy( |
| 82 | + position[self.array_key] |
| 83 | + .oindex[time_idx, list(self.channels.values())] |
| 84 | + .astype(np.float32) |
| 85 | + ) |
| 86 | + volume = self._preprocess_volume(volume, norm_meta) |
| 87 | + _logger.debug(f"Caching for index {idx}") |
| 88 | + self._cache_map[idx] = True |
| 89 | + self._buffer[idx] = volume |
| 90 | + else: |
| 91 | + _logger.debug(f"Using cached volume for index {idx}") |
| 92 | + volume = self._buffer[idx] |
| 93 | + sample = self._split_channels(volume) |
| 94 | + if self.cpu_transform: |
| 95 | + sample = self.cpu_transform(sample) |
| 96 | + if not isinstance(sample, list): |
| 97 | + sample = [sample] |
| 98 | + return sample |
| 99 | + |
| 100 | + |
| 101 | +class MmappedDataModule(GPUTransformDataModule, SelectWell): |
| 102 | + """Data module for cached OME-Zarr arrays. |
| 103 | +
|
| 104 | + Parameters |
| 105 | + ---------- |
| 106 | + data_path : Path |
| 107 | + Path to the HCS OME-Zarr dataset. |
| 108 | + channels : str | list[str] |
| 109 | + Channel names to load. |
| 110 | + batch_size : int |
| 111 | + Batch size for training and validation. |
| 112 | + num_workers : int |
| 113 | + Number of workers for data-loaders. |
| 114 | + split_ratio : float |
| 115 | + Fraction of the FOVs used for the training split. |
| 116 | + The rest will be used for validation. |
| 117 | + train_cpu_transforms : list[DictTransform] |
| 118 | + Transforms to be applied on the CPU during training. |
| 119 | + val_cpu_transforms : list[DictTransform] |
| 120 | + Transforms to be applied on the CPU during validation. |
| 121 | + train_gpu_transforms : list[DictTransform] |
| 122 | + Transforms to be applied on the GPU during training. |
| 123 | + val_gpu_transforms : list[DictTransform] |
| 124 | + Transforms to be applied on the GPU during validation. |
| 125 | + pin_memory : bool, optional |
| 126 | + Use page-locked memory in data-loaders, by default True |
| 127 | + prefetch_factor : int | None, optional |
| 128 | + Prefetching ratio for the torch dataloader, by default None |
| 129 | + array_key : str, optional |
| 130 | + Name of the image arrays (multiscales level), by default "0" |
| 131 | + scratch_dir : Path | None, optional |
| 132 | + Path to the scratch directory, |
| 133 | + by default None (use OS temporary data directory) |
| 134 | + include_wells : list[str] | None, optional |
| 135 | + Include only a subset of wells, by default None (include all wells) |
| 136 | + exclude_fovs : list[str] | None, optional |
| 137 | + Exclude FOVs, by default None (do not exclude any FOVs) |
| 138 | + """ |
| 139 | + |
| 140 | + def __init__( |
| 141 | + self, |
| 142 | + data_path: Path, |
| 143 | + channels: str | list[str], |
| 144 | + batch_size: int, |
| 145 | + num_workers: int, |
| 146 | + split_ratio: float, |
| 147 | + preprocess_transforms: list[DictTransform], |
| 148 | + train_cpu_transforms: list[DictTransform], |
| 149 | + val_cpu_transforms: list[DictTransform], |
| 150 | + train_gpu_transforms: list[DictTransform], |
| 151 | + val_gpu_transforms: list[DictTransform], |
| 152 | + pin_memory: bool = True, |
| 153 | + prefetch_factor: int | None = None, |
| 154 | + array_key: str = "0", |
| 155 | + scratch_dir: Path | None = None, |
| 156 | + include_wells: list[str] | None = None, |
| 157 | + exclude_fovs: list[str] | None = None, |
| 158 | + ): |
| 159 | + super().__init__() |
| 160 | + self.data_path = Path(data_path) |
| 161 | + self.channels = _ensure_channel_list(channels) |
| 162 | + self.batch_size = batch_size |
| 163 | + self.num_workers = num_workers |
| 164 | + self.split_ratio = split_ratio |
| 165 | + self._preprocessing_transforms = Compose(preprocess_transforms) |
| 166 | + self._train_cpu_transforms = Compose(train_cpu_transforms) |
| 167 | + self._val_cpu_transforms = Compose(val_cpu_transforms) |
| 168 | + self._train_gpu_transforms = Compose(train_gpu_transforms) |
| 169 | + self._val_gpu_transforms = Compose(val_gpu_transforms) |
| 170 | + self.pin_memory = pin_memory |
| 171 | + self.array_key = array_key |
| 172 | + self.scratch_dir = scratch_dir |
| 173 | + self._include_wells = include_wells |
| 174 | + self._exclude_fovs = exclude_fovs |
| 175 | + self.prepare_data_per_node = True |
| 176 | + self.prefetch_factor = prefetch_factor if self.num_workers > 0 else None |
| 177 | + |
| 178 | + @property |
| 179 | + def preprocessing_transforms(self) -> Compose: |
| 180 | + return self._preprocessing_transforms |
| 181 | + |
| 182 | + @property |
| 183 | + def train_cpu_transforms(self) -> Compose: |
| 184 | + return self._train_cpu_transforms |
| 185 | + |
| 186 | + @property |
| 187 | + def train_gpu_transforms(self) -> Compose: |
| 188 | + return self._train_gpu_transforms |
| 189 | + |
| 190 | + @property |
| 191 | + def val_cpu_transforms(self) -> Compose: |
| 192 | + return self._val_cpu_transforms |
| 193 | + |
| 194 | + @property |
| 195 | + def val_gpu_transforms(self) -> Compose: |
| 196 | + return self._val_gpu_transforms |
| 197 | + |
| 198 | + @property |
| 199 | + def cache_dir(self) -> Path: |
| 200 | + scratch_dir = self.scratch_dir or Path(tempfile.gettempdir()) |
| 201 | + cache_dir = Path( |
| 202 | + scratch_dir, |
| 203 | + os.getenv("SLURM_JOB_ID", "viscy_cache"), |
| 204 | + str( |
| 205 | + torch.distributed.get_rank() |
| 206 | + if torch.distributed.is_initialized() |
| 207 | + else 0 |
| 208 | + ), |
| 209 | + self.data_path.name, |
| 210 | + ) |
| 211 | + cache_dir.mkdir(parents=True, exist_ok=True) |
| 212 | + return cache_dir |
| 213 | + |
| 214 | + def _set_fit_global_state(self, num_positions: int) -> list[int]: |
| 215 | + # disable metadata tracking in MONAI for performance |
| 216 | + set_track_meta(False) |
| 217 | + # shuffle positions, randomness is handled globally |
| 218 | + return torch.randperm(num_positions).tolist() |
| 219 | + |
| 220 | + def _buffer_shape(self, arr_shape, fovs) -> tuple[int, ...]: |
| 221 | + return (len(fovs) * arr_shape[0], len(self.channels), *arr_shape[2:]) |
| 222 | + |
| 223 | + def setup(self, stage: Literal["fit", "validate"]) -> None: |
| 224 | + if stage not in ("fit", "validate"): |
| 225 | + raise NotImplementedError("Only fit and validate stages are supported.") |
| 226 | + plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs") |
| 227 | + positions = self._filter_fit_fovs(plate) |
| 228 | + arr_shape = positions[0][self.array_key].shape |
| 229 | + shuffled_indices = self._set_fit_global_state(len(positions)) |
| 230 | + num_train_fovs = int(len(positions) * self.split_ratio) |
| 231 | + train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]] |
| 232 | + val_fovs = [positions[i] for i in shuffled_indices[num_train_fovs:]] |
| 233 | + _logger.debug(f"Training FOVs: {[p.zgroup.name for p in train_fovs]}") |
| 234 | + _logger.debug(f"Validation FOVs: {[p.zgroup.name for p in val_fovs]}") |
| 235 | + train_buffer = MemoryMappedTensor.empty( |
| 236 | + self._buffer_shape(arr_shape, train_fovs), |
| 237 | + dtype=torch.float32, |
| 238 | + filename=self.cache_dir / "train.mmap", |
| 239 | + ) |
| 240 | + val_buffer = MemoryMappedTensor.empty( |
| 241 | + self._buffer_shape(arr_shape, val_fovs), |
| 242 | + dtype=torch.float32, |
| 243 | + filename=self.cache_dir / "val.mmap", |
| 244 | + ) |
| 245 | + cache_map_train = Manager().dict() |
| 246 | + self.train_dataset = MmappedDataset( |
| 247 | + positions=train_fovs, |
| 248 | + channel_names=self.channels, |
| 249 | + cache_map=cache_map_train, |
| 250 | + buffer=train_buffer, |
| 251 | + preprocess_transforms=self.preprocessing_transforms, |
| 252 | + cpu_transform=self.train_cpu_transforms, |
| 253 | + array_key=self.array_key, |
| 254 | + ) |
| 255 | + cache_map_val = Manager().dict() |
| 256 | + self.val_dataset = MmappedDataset( |
| 257 | + positions=val_fovs, |
| 258 | + channel_names=self.channels, |
| 259 | + cache_map=cache_map_val, |
| 260 | + buffer=val_buffer, |
| 261 | + preprocess_transforms=self.preprocessing_transforms, |
| 262 | + cpu_transform=self.val_cpu_transforms, |
| 263 | + array_key=self.array_key, |
| 264 | + ) |
0 commit comments