Skip to content

Commit 89b2917

Browse files
ziw-liuedyoshikun
andauthored
Memory-mapped caching for image translation training (#218)
* caching dataloader * caching data module * black * ruff * Bump torch to 2.4.1 (#174) * update torch >2.4.1 * black * ruff * adding timeout to ram_dataloader * bandaid to cached dataloader * fixing the dataloader using torch collate_fn * replacing dictionary with single array * loading prior to epoch 0 * Revert "replacing dictionary with single array" This reverts commit 8c13f49. * using multiprocessing manager * add sharded distributed sampler * add example script for ddp caching * format and lint * addding the custom distrb sampler to hcs_ram.py * adding sampler to val train dataloader * fix divisibility of the last shard * hcs_ram format and lint * data module that only crops and does not collate * wip: execute transforms on the GPU * path for if not ddp * fix randomness in inversion transform * add option to pop the normalization metadata * move gpu transform definition back to data module * add tiled crop transform for validation * add stack channel transform for gpu augmentation * fix typing * collate before sending to gpu * inherit gpu transforms for livecell dataset * update fcmae engine to apply per-dataset augmentations * format and lint hcs_ram * fix abc type hint * update docstring style * disable grad for validation transforms * improve sample image logging in fcmae * fix dataset length when batch size is larger than the dataset * fix docstring * add option to disable normalization metadata * inherit gpu transform for ctmc * remove duplicate method overrride * update docstring for ctmc * allow skipping caching for large datasets * make the fcmae module compatible with image translation * remove prototype implementation * fix import path * Arbitrary prediction time transforms (#209) * fix spelling in docstring and comment * add batched zoom transform for tta * add standalone lightning module for arbitrary TTA * fix composition of different zoom factors * add docstrings * wip: segmentation module * avoid casting * update import path from iohub * make integer array in fixture * labels fixture * test segmentation metrics modules * less strings * test non-empty * select which wells to include in fit #205 * make well selection a mixin * wip: mmap cache data module * support exclusion of FOVs * wip: precompute normalization * add augmentations benchmark * fix cpu threads default * fix probability (affects cpu results) * disable metadata tracking * fix non-distributed initialization * refactor transforms into submodules * do not import type hints at runtime * update docstring * backwards compatible import path * fix annotations * fix style * fix dice score import * fix dice score parameters * apply formatting to exercise * fix labels data type * fix labels input shape --------- Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
1 parent 42a0228 commit 89b2917

File tree

15 files changed

+1714
-2270
lines changed

15 files changed

+1714
-2270
lines changed

examples/virtual_staining/dlmbl_exercise/exercise.ipynb

Lines changed: 270 additions & 184 deletions
Large diffs are not rendered by default.

examples/virtual_staining/dlmbl_exercise/solution.ipynb

Lines changed: 346 additions & 1690 deletions
Large diffs are not rendered by default.

examples/virtual_staining/dlmbl_exercise/solution.py

Lines changed: 370 additions & 239 deletions
Large diffs are not rendered by default.

viscy/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def subcommands() -> dict[str, set[str]]:
2121
subcommand_base_args = {"model"}
2222
subcommands["preprocess"] = subcommand_base_args
2323
subcommands["export"] = subcommand_base_args
24+
subcommands["precompute"] = subcommand_base_args
2425
return subcommands
2526

2627
def add_arguments_to_parser(self, parser) -> None:
@@ -50,8 +51,8 @@ def main() -> None:
5051
Set default random seed to 42.
5152
"""
5253
_setup_environment()
53-
require_model = "preprocess" not in sys.argv
54-
require_data = {"preprocess", "export"}.isdisjoint(sys.argv)
54+
require_model = {"preprocess", "precompute"}.isdisjoint(sys.argv)
55+
require_data = {"preprocess", "precompute", "export"}.isdisjoint(sys.argv)
5556
_ = VisCyCLI(
5657
model_class=LightningModule,
5758
datamodule_class=LightningDataModule if require_data else None,

viscy/data/gpu_aug.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from viscy.data.distributed import ShardedDistributedSampler
2020
from viscy.data.hcs import _ensure_channel_list, _read_norm_meta
2121
from viscy.data.typing import DictTransform, NormMeta
22+
from viscy.preprocessing.precompute import _filter_fovs, _filter_wells
2223

2324
if TYPE_CHECKING:
2425
from multiprocessing.managers import DictProxy
@@ -36,6 +37,7 @@ class GPUTransformDataModule(ABC, LightningDataModule):
3637
batch_size: int
3738
num_workers: int
3839
pin_memory: bool
40+
prefetch_factor: int | None
3941

4042
def _maybe_sampler(
4143
self, dataset: Dataset, shuffle: bool
@@ -59,6 +61,7 @@ def train_dataloader(self) -> DataLoader:
5961
pin_memory=self.pin_memory,
6062
drop_last=False,
6163
collate_fn=list_data_collate,
64+
prefetch_factor=self.prefetch_factor,
6265
)
6366

6467
def val_dataloader(self) -> DataLoader:
@@ -74,6 +77,7 @@ def val_dataloader(self) -> DataLoader:
7477
pin_memory=self.pin_memory,
7578
drop_last=False,
7679
collate_fn=list_data_collate,
80+
prefetch_factor=self.prefetch_factor,
7781
)
7882

7983
@property
@@ -169,7 +173,23 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
169173
return sample
170174

171175

172-
class CachedOmeZarrDataModule(GPUTransformDataModule):
176+
class SelectWell:
177+
_include_wells: list[str] | None
178+
_exclude_fovs: list[str] | None
179+
180+
def _filter_fit_fovs(self, plate: Plate) -> list[Position]:
181+
positions = []
182+
for well in _filter_wells(plate, include_wells=self._include_wells):
183+
for fov in _filter_fovs(well, exclude_fovs=self._exclude_fovs):
184+
positions.append(fov)
185+
if len(positions) < 2:
186+
raise ValueError(
187+
"At least 2 FOVs are required for training and validation."
188+
)
189+
return positions
190+
191+
192+
class CachedOmeZarrDataModule(GPUTransformDataModule, SelectWell):
173193
"""Data module for cached OME-Zarr arrays.
174194
175195
Parameters
@@ -199,6 +219,8 @@ class CachedOmeZarrDataModule(GPUTransformDataModule):
199219
Skip caching for this dataset, by default False
200220
include_wells : list[str], optional
201221
List of well names to include in the dataset, by default None (all)
222+
include_wells : list[str], optional
223+
List of well names to include in the dataset, by default None (all)
202224
"""
203225

204226
def __init__(
@@ -215,6 +237,7 @@ def __init__(
215237
pin_memory: bool = True,
216238
skip_cache: bool = False,
217239
include_wells: list[str] | None = None,
240+
exclude_fovs: list[str] | None = None,
218241
):
219242
super().__init__()
220243
self.data_path = data_path
@@ -229,6 +252,7 @@ def __init__(
229252
self.pin_memory = pin_memory
230253
self.skip_cache = skip_cache
231254
self._include_wells = include_wells
255+
self._exclude_fovs = exclude_fovs
232256

233257
@property
234258
def train_cpu_transforms(self) -> Compose:

viscy/data/mmap_cache.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import tempfile
5+
from logging import getLogger
6+
from pathlib import Path
7+
from typing import TYPE_CHECKING, Literal
8+
9+
import numpy as np
10+
import torch
11+
from iohub.ngff import Plate, Position, open_ome_zarr
12+
from monai.data.meta_obj import set_track_meta
13+
from monai.transforms.compose import Compose
14+
from tensordict.memmap import MemoryMappedTensor
15+
from torch import Tensor
16+
from torch.multiprocessing import Manager
17+
from torch.utils.data import Dataset
18+
19+
from viscy.data.gpu_aug import GPUTransformDataModule, SelectWell
20+
from viscy.data.hcs import _ensure_channel_list, _read_norm_meta
21+
from viscy.data.typing import DictTransform, NormMeta
22+
23+
if TYPE_CHECKING:
24+
from multiprocessing.managers import DictProxy
25+
26+
_logger = getLogger("lightning.pytorch")
27+
28+
_CacheMetadata = tuple[Position, int, NormMeta | None]
29+
30+
31+
class MmappedDataset(Dataset):
32+
def __init__(
33+
self,
34+
positions: list[Position],
35+
channel_names: list[str],
36+
cache_map: DictProxy,
37+
buffer: MemoryMappedTensor,
38+
preprocess_transforms: Compose | None = None,
39+
cpu_transform: Compose | None = None,
40+
array_key: str = "0",
41+
load_normalization_metadata: bool = True,
42+
):
43+
key = 0
44+
self._metadata_map: dict[int, _CacheMetadata] = {}
45+
for position in positions:
46+
img = position[array_key]
47+
norm_meta = _read_norm_meta(position)
48+
for time_idx in range(img.frames):
49+
cache_map[key] = None
50+
self._metadata_map[key] = (position, time_idx, norm_meta)
51+
key += 1
52+
self.channels = {ch: position.get_channel_index(ch) for ch in channel_names}
53+
self.array_key = array_key
54+
self._buffer = buffer
55+
self._cache_map = cache_map
56+
self.preprocess_transforms = preprocess_transforms
57+
self.cpu_transform = cpu_transform
58+
self.load_normalization_metadata = load_normalization_metadata
59+
60+
def __len__(self) -> int:
61+
return len(self._metadata_map)
62+
63+
def _split_channels(self, volume: Tensor) -> dict[str, Tensor]:
64+
return {name: img[None] for name, img in zip(self.channels.keys(), volume)}
65+
66+
def _preprocess_volume(self, volume: Tensor, norm_meta) -> Tensor:
67+
if self.preprocess_transforms:
68+
orig_shape = volume.shape
69+
sample = self._split_channels(volume)
70+
if self.load_normalization_metadata:
71+
sample["norm_meta"] = norm_meta
72+
sample = self.preprocess_transforms(sample)
73+
volume = torch.cat([sample[name] for name in self.channels.keys()], dim=0)
74+
assert volume.shape == orig_shape, (volume.shape, orig_shape, sample.keys())
75+
return volume
76+
77+
def __getitem__(self, idx: int) -> dict[str, Tensor]:
78+
position, time_idx, norm_meta = self._metadata_map[idx]
79+
if not self._cache_map[idx]:
80+
_logger.debug(f"Loading volume for index {idx}")
81+
volume = torch.from_numpy(
82+
position[self.array_key]
83+
.oindex[time_idx, list(self.channels.values())]
84+
.astype(np.float32)
85+
)
86+
volume = self._preprocess_volume(volume, norm_meta)
87+
_logger.debug(f"Caching for index {idx}")
88+
self._cache_map[idx] = True
89+
self._buffer[idx] = volume
90+
else:
91+
_logger.debug(f"Using cached volume for index {idx}")
92+
volume = self._buffer[idx]
93+
sample = self._split_channels(volume)
94+
if self.cpu_transform:
95+
sample = self.cpu_transform(sample)
96+
if not isinstance(sample, list):
97+
sample = [sample]
98+
return sample
99+
100+
101+
class MmappedDataModule(GPUTransformDataModule, SelectWell):
102+
"""Data module for cached OME-Zarr arrays.
103+
104+
Parameters
105+
----------
106+
data_path : Path
107+
Path to the HCS OME-Zarr dataset.
108+
channels : str | list[str]
109+
Channel names to load.
110+
batch_size : int
111+
Batch size for training and validation.
112+
num_workers : int
113+
Number of workers for data-loaders.
114+
split_ratio : float
115+
Fraction of the FOVs used for the training split.
116+
The rest will be used for validation.
117+
train_cpu_transforms : list[DictTransform]
118+
Transforms to be applied on the CPU during training.
119+
val_cpu_transforms : list[DictTransform]
120+
Transforms to be applied on the CPU during validation.
121+
train_gpu_transforms : list[DictTransform]
122+
Transforms to be applied on the GPU during training.
123+
val_gpu_transforms : list[DictTransform]
124+
Transforms to be applied on the GPU during validation.
125+
pin_memory : bool, optional
126+
Use page-locked memory in data-loaders, by default True
127+
prefetch_factor : int | None, optional
128+
Prefetching ratio for the torch dataloader, by default None
129+
array_key : str, optional
130+
Name of the image arrays (multiscales level), by default "0"
131+
scratch_dir : Path | None, optional
132+
Path to the scratch directory,
133+
by default None (use OS temporary data directory)
134+
include_wells : list[str] | None, optional
135+
Include only a subset of wells, by default None (include all wells)
136+
exclude_fovs : list[str] | None, optional
137+
Exclude FOVs, by default None (do not exclude any FOVs)
138+
"""
139+
140+
def __init__(
141+
self,
142+
data_path: Path,
143+
channels: str | list[str],
144+
batch_size: int,
145+
num_workers: int,
146+
split_ratio: float,
147+
preprocess_transforms: list[DictTransform],
148+
train_cpu_transforms: list[DictTransform],
149+
val_cpu_transforms: list[DictTransform],
150+
train_gpu_transforms: list[DictTransform],
151+
val_gpu_transforms: list[DictTransform],
152+
pin_memory: bool = True,
153+
prefetch_factor: int | None = None,
154+
array_key: str = "0",
155+
scratch_dir: Path | None = None,
156+
include_wells: list[str] | None = None,
157+
exclude_fovs: list[str] | None = None,
158+
):
159+
super().__init__()
160+
self.data_path = Path(data_path)
161+
self.channels = _ensure_channel_list(channels)
162+
self.batch_size = batch_size
163+
self.num_workers = num_workers
164+
self.split_ratio = split_ratio
165+
self._preprocessing_transforms = Compose(preprocess_transforms)
166+
self._train_cpu_transforms = Compose(train_cpu_transforms)
167+
self._val_cpu_transforms = Compose(val_cpu_transforms)
168+
self._train_gpu_transforms = Compose(train_gpu_transforms)
169+
self._val_gpu_transforms = Compose(val_gpu_transforms)
170+
self.pin_memory = pin_memory
171+
self.array_key = array_key
172+
self.scratch_dir = scratch_dir
173+
self._include_wells = include_wells
174+
self._exclude_fovs = exclude_fovs
175+
self.prepare_data_per_node = True
176+
self.prefetch_factor = prefetch_factor if self.num_workers > 0 else None
177+
178+
@property
179+
def preprocessing_transforms(self) -> Compose:
180+
return self._preprocessing_transforms
181+
182+
@property
183+
def train_cpu_transforms(self) -> Compose:
184+
return self._train_cpu_transforms
185+
186+
@property
187+
def train_gpu_transforms(self) -> Compose:
188+
return self._train_gpu_transforms
189+
190+
@property
191+
def val_cpu_transforms(self) -> Compose:
192+
return self._val_cpu_transforms
193+
194+
@property
195+
def val_gpu_transforms(self) -> Compose:
196+
return self._val_gpu_transforms
197+
198+
@property
199+
def cache_dir(self) -> Path:
200+
scratch_dir = self.scratch_dir or Path(tempfile.gettempdir())
201+
cache_dir = Path(
202+
scratch_dir,
203+
os.getenv("SLURM_JOB_ID", "viscy_cache"),
204+
str(
205+
torch.distributed.get_rank()
206+
if torch.distributed.is_initialized()
207+
else 0
208+
),
209+
self.data_path.name,
210+
)
211+
cache_dir.mkdir(parents=True, exist_ok=True)
212+
return cache_dir
213+
214+
def _set_fit_global_state(self, num_positions: int) -> list[int]:
215+
# disable metadata tracking in MONAI for performance
216+
set_track_meta(False)
217+
# shuffle positions, randomness is handled globally
218+
return torch.randperm(num_positions).tolist()
219+
220+
def _buffer_shape(self, arr_shape, fovs) -> tuple[int, ...]:
221+
return (len(fovs) * arr_shape[0], len(self.channels), *arr_shape[2:])
222+
223+
def setup(self, stage: Literal["fit", "validate"]) -> None:
224+
if stage not in ("fit", "validate"):
225+
raise NotImplementedError("Only fit and validate stages are supported.")
226+
plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs")
227+
positions = self._filter_fit_fovs(plate)
228+
arr_shape = positions[0][self.array_key].shape
229+
shuffled_indices = self._set_fit_global_state(len(positions))
230+
num_train_fovs = int(len(positions) * self.split_ratio)
231+
train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]]
232+
val_fovs = [positions[i] for i in shuffled_indices[num_train_fovs:]]
233+
_logger.debug(f"Training FOVs: {[p.zgroup.name for p in train_fovs]}")
234+
_logger.debug(f"Validation FOVs: {[p.zgroup.name for p in val_fovs]}")
235+
train_buffer = MemoryMappedTensor.empty(
236+
self._buffer_shape(arr_shape, train_fovs),
237+
dtype=torch.float32,
238+
filename=self.cache_dir / "train.mmap",
239+
)
240+
val_buffer = MemoryMappedTensor.empty(
241+
self._buffer_shape(arr_shape, val_fovs),
242+
dtype=torch.float32,
243+
filename=self.cache_dir / "val.mmap",
244+
)
245+
cache_map_train = Manager().dict()
246+
self.train_dataset = MmappedDataset(
247+
positions=train_fovs,
248+
channel_names=self.channels,
249+
cache_map=cache_map_train,
250+
buffer=train_buffer,
251+
preprocess_transforms=self.preprocessing_transforms,
252+
cpu_transform=self.train_cpu_transforms,
253+
array_key=self.array_key,
254+
)
255+
cache_map_val = Manager().dict()
256+
self.val_dataset = MmappedDataset(
257+
positions=val_fovs,
258+
channel_names=self.channels,
259+
cache_map=cache_map_val,
260+
buffer=val_buffer,
261+
preprocess_transforms=self.preprocessing_transforms,
262+
cpu_transform=self.val_cpu_transforms,
263+
array_key=self.array_key,
264+
)

0 commit comments

Comments
 (0)