
Commit a2186d6

ziw-liu and edyoshikun authored
Batched dataloading for contrastive learning (#276)
* add tensorstore and kornia dependencies
* use __getitems__ for batched dataloading
* implement accelerated transforms
* update profiling script
* remove profilehooks dependency
* export new transforms
* add batched version of ConcatDataModule
* add debug logging
* collate normalization metadata
* do not require index
* match statistics with image
* fix gpu transforms for ddp
* add the decollated and todeviced transforms
* style
* preserve query ordering
* implement predict dataloader hook
* revert accidental diff
* restore module execution for debugging
* fix tuple index

---------

Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
1 parent ac7ac63 commit a2186d6

9 files changed: +596 -90 lines changed
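The core of this commit is batched dataloading through __getitems__. As a rough, hypothetical illustration (not code from this repository): when a map-style dataset defines __getitems__, the default DataLoader fetcher in recent PyTorch (the project pins torch>=2.4.1) passes it the whole list of batch indices, so a single vectorized read can replace the per-sample __getitem__ loop. The BatchedArrayDataset below and its random tensor are stand-ins for the tensorstore-backed datasets used in the project.

# Minimal sketch, assuming a recent PyTorch (the project pins torch>=2.4.1)
# whose default fetcher calls dataset.__getitems__(indices) when it is defined.
import torch
from torch.utils.data import DataLoader, Dataset


class BatchedArrayDataset(Dataset):
    """Hypothetical dataset backed by a single array-like store."""

    def __init__(self, array: torch.Tensor) -> None:
        self.array = array

    def __len__(self) -> int:
        return len(self.array)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # per-sample path kept for completeness
        return self.array[idx]

    def __getitems__(self, indices: list[int]) -> list[torch.Tensor]:
        # one fancy-indexed read for the whole batch, then split into samples
        batch = self.array[torch.as_tensor(indices)]
        return list(batch.unbind(0))


loader = DataLoader(BatchedArrayDataset(torch.randn(100, 3, 32, 32)), batch_size=8)
first_batch = next(iter(loader))  # default collate stacks back into (8, 3, 32, 32)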

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,8 @@ classifiers = [
     "Programming Language :: Python :: 3.13",
 ]
 dependencies = [
-    "iohub>=0.2.0b0",
+    "iohub[tensorstore]>=0.2.2rc0",
+    "kornia",
     "torch>=2.4.1",
     "timm>=0.9.5",
     "tensorboard>=2.13.0",
@@ -61,7 +62,6 @@ dev = [
     "pytest-cov",
     "hypothesis",
     "ruff",
-    "profilehooks",
     "onnxruntime",
 ]
 

viscy/cli.py

Lines changed: 4 additions & 0 deletions
@@ -64,3 +64,7 @@ def main() -> None:
             "description": "Computer vision models for single-cell phenotyping."
         },
     )
+
+
+if __name__ == "__main__":
+    main()

viscy/data/combined.py

Lines changed: 83 additions & 7 deletions
@@ -1,14 +1,20 @@
+import bisect
+import logging
+from collections import defaultdict
 from enum import Enum
 from typing import Literal, Sequence
 
 import torch
 from lightning.pytorch import LightningDataModule
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
+from monai.data import ThreadDataLoader
 from torch.utils.data import ConcatDataset, DataLoader, Dataset
 
 from viscy.data.distributed import ShardedDistributedSampler
 from viscy.data.hcs import _collate_samples
 
+_logger = logging.getLogger("lightning.pytorch")
+
 
 class CombineMode(Enum):
     MIN_SIZE = "min_size"
@@ -82,6 +88,37 @@ def predict_dataloader(self):
         )
 
 
+class BatchedConcatDataset(ConcatDataset):
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+    def _get_sample_indices(self, idx: int) -> tuple[int, int]:
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return dataset_idx, sample_idx
+
+    def __getitems__(self, indices: list[int]) -> list:
+        grouped_indices = defaultdict(list)
+        for idx in indices:
+            dataset_idx, sample_indices = self._get_sample_indices(idx)
+            grouped_indices[dataset_idx].append(sample_indices)
+        _logger.debug(f"Grouped indices: {grouped_indices}")
+        sub_batches = []
+        for dataset_idx, sample_indices in grouped_indices.items():
+            sub_batch = self.datasets[dataset_idx].__getitems__(sample_indices)
+            sub_batches.extend(sub_batch)
+        return sub_batches
+
+
 class ConcatDataModule(LightningDataModule):
     """
     Concatenate multiple data modules.
@@ -96,11 +133,16 @@ class ConcatDataModule(LightningDataModule):
         Data modules to concatenate.
     """
 
+    _ConcatDataset = ConcatDataset
+
     def __init__(self, data_modules: Sequence[LightningDataModule]):
         super().__init__()
         self.data_modules = data_modules
         self.num_workers = data_modules[0].num_workers
         self.batch_size = data_modules[0].batch_size
+        self.persistent_workers = data_modules[0].persistent_workers
+        self.prefetch_factor = data_modules[0].prefetch_factor
+        self.pin_memory = data_modules[0].pin_memory
         for dm in data_modules:
             if dm.num_workers != self.num_workers:
                 raise ValueError("Inconsistent number of workers")
@@ -124,28 +166,62 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
                 raise ValueError("Inconsistent patches per stack")
         if stage != "fit":
             raise NotImplementedError("Only fit stage is supported")
-        self.train_dataset = ConcatDataset(
+        self.train_dataset = self._ConcatDataset(
             [dm.train_dataset for dm in self.data_modules]
         )
-        self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules])
+        self.val_dataset = self._ConcatDataset(
+            [dm.val_dataset for dm in self.data_modules]
+        )
+
+    def _dataloader_kwargs(self) -> dict:
+        return {
+            "num_workers": self.num_workers,
+            "persistent_workers": self.persistent_workers,
+            "prefetch_factor": self.prefetch_factor if self.num_workers else None,
+            "pin_memory": self.pin_memory,
+        }
 
     def train_dataloader(self):
         return DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size // self.train_patches_per_stack,
-            num_workers=self.num_workers,
             shuffle=True,
-            persistent_workers=bool(self.num_workers),
+            batch_size=self.batch_size // self.train_patches_per_stack,
             collate_fn=_collate_samples,
+            drop_last=True,
+            **self._dataloader_kwargs(),
         )
 
     def val_dataloader(self):
         return DataLoader(
             self.val_dataset,
+            shuffle=False,
+            batch_size=self.batch_size,
+            drop_last=False,
+            **self._dataloader_kwargs(),
+        )
+
+
+class BatchedConcatDataModule(ConcatDataModule):
+    _ConcatDataset = BatchedConcatDataset
+
+    def train_dataloader(self):
+        return ThreadDataLoader(
+            self.train_dataset,
+            use_thread_workers=True,
+            batch_size=self.batch_size,
+            shuffle=True,
+            drop_last=True,
+            **self._dataloader_kwargs(),
+        )
+
+    def val_dataloader(self):
+        return ThreadDataLoader(
+            self.val_dataset,
+            use_thread_workers=True,
             batch_size=self.batch_size,
-            num_workers=self.num_workers,
             shuffle=False,
-            persistent_workers=bool(self.num_workers),
+            drop_last=False,
+            **self._dataloader_kwargs(),
         )
 
 
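To see what the new BatchedConcatDataset does with a batch of global indices, here is a small hedged sketch (not from the repository). TinyDataset is a hypothetical constituent dataset that only supports batched access; the example shows how __getitems__ splits the requested indices by owning dataset, forwards each group to that dataset's __getitems__, and returns the sub-batches grouped per dataset. BatchedConcatDataModule then feeds such a dataset to MONAI's ThreadDataLoader, which with use_thread_workers=True runs its workers as threads rather than subprocesses.

# Hedged sketch; TinyDataset is hypothetical and stands in for the project's
# array-backed datasets that implement __getitems__.
from torch.utils.data import Dataset

from viscy.data.combined import BatchedConcatDataset


class TinyDataset(Dataset):
    def __init__(self, offset: int, length: int) -> None:
        self.offset, self.length = offset, length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> int:
        raise NotImplementedError  # only batched access is exercised

    def __getitems__(self, indices: list[int]) -> list[int]:
        return [self.offset + i for i in indices]


ds = BatchedConcatDataset([TinyDataset(0, 4), TinyDataset(100, 4)])
# global indices 0 and 1 hit the first dataset; 5 and 6 map to local indices 1
# and 2 of the second, so the result comes back grouped per dataset: [0, 1, 101, 102]
print(ds.__getitems__([0, 5, 1, 6]))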