+import bisect
+import logging
+from collections import defaultdict
 from enum import Enum
 from typing import Literal, Sequence
 
 import torch
 from lightning.pytorch import LightningDataModule
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
+from monai.data import ThreadDataLoader
 from torch.utils.data import ConcatDataset, DataLoader, Dataset
 
 from viscy.data.distributed import ShardedDistributedSampler
 from viscy.data.hcs import _collate_samples
 
+_logger = logging.getLogger("lightning.pytorch")
+
 
 class CombineMode(Enum):
     MIN_SIZE = "min_size"
@@ -82,6 +88,37 @@ def predict_dataloader(self):
         )
 
 
+class BatchedConcatDataset(ConcatDataset):
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+    def _get_sample_indices(self, idx: int) -> tuple[int, int]:
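+        # Same dataset/offset arithmetic as ConcatDataset.__getitem__,
+        # but return the (dataset_idx, sample_idx) pair instead of the sample.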
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return dataset_idx, sample_idx
+
+    def __getitems__(self, indices: list[int]) -> list:
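+        # Group the requested indices by source dataset so that each
+        # underlying dataset can fetch its share of the batch in one call.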
+        grouped_indices = defaultdict(list)
+        for idx in indices:
+            dataset_idx, sample_indices = self._get_sample_indices(idx)
+            grouped_indices[dataset_idx].append(sample_indices)
+        _logger.debug(f"Grouped indices: {grouped_indices}")
+        sub_batches = []
+        for dataset_idx, sample_indices in grouped_indices.items():
+            sub_batch = self.datasets[dataset_idx].__getitems__(sample_indices)
+            sub_batches.extend(sub_batch)
+        return sub_batches
+
+
 class ConcatDataModule(LightningDataModule):
     """
     Concatenate multiple data modules.
@@ -96,11 +133,16 @@ class ConcatDataModule(LightningDataModule):
         Data modules to concatenate.
     """
 
+    _ConcatDataset = ConcatDataset
+
     def __init__(self, data_modules: Sequence[LightningDataModule]):
         super().__init__()
         self.data_modules = data_modules
         self.num_workers = data_modules[0].num_workers
         self.batch_size = data_modules[0].batch_size
+        self.persistent_workers = data_modules[0].persistent_workers
+        self.prefetch_factor = data_modules[0].prefetch_factor
+        self.pin_memory = data_modules[0].pin_memory
         for dm in data_modules:
             if dm.num_workers != self.num_workers:
                 raise ValueError("Inconsistent number of workers")
@@ -124,28 +166,62 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
             raise ValueError("Inconsistent patches per stack")
         if stage != "fit":
             raise NotImplementedError("Only fit stage is supported")
-        self.train_dataset = ConcatDataset(
+        self.train_dataset = self._ConcatDataset(
             [dm.train_dataset for dm in self.data_modules]
         )
-        self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules])
+        self.val_dataset = self._ConcatDataset(
+            [dm.val_dataset for dm in self.data_modules]
+        )
+
+    def _dataloader_kwargs(self) -> dict:
+        return {
+            "num_workers": self.num_workers,
+            "persistent_workers": self.persistent_workers,
+            "prefetch_factor": self.prefetch_factor if self.num_workers else None,
+            "pin_memory": self.pin_memory,
+        }
 
     def train_dataloader(self):
         return DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size // self.train_patches_per_stack,
-            num_workers=self.num_workers,
             shuffle=True,
-            persistent_workers=bool(self.num_workers),
+            batch_size=self.batch_size // self.train_patches_per_stack,
             collate_fn=_collate_samples,
+            drop_last=True,
+            **self._dataloader_kwargs(),
         )
 
     def val_dataloader(self):
         return DataLoader(
             self.val_dataset,
+            shuffle=False,
+            batch_size=self.batch_size,
+            drop_last=False,
+            **self._dataloader_kwargs(),
+        )
+
+
+class BatchedConcatDataModule(ConcatDataModule):
+    _ConcatDataset = BatchedConcatDataset
+
+    def train_dataloader(self):
+        return ThreadDataLoader(
+            self.train_dataset,
+            use_thread_workers=True,
+            batch_size=self.batch_size,
+            shuffle=True,
+            drop_last=True,
+            **self._dataloader_kwargs(),
+        )
+
+    def val_dataloader(self):
+        return ThreadDataLoader(
+            self.val_dataset,
+            use_thread_workers=True,
             batch_size=self.batch_size,
-            num_workers=self.num_workers,
             shuffle=False,
-            persistent_workers=bool(self.num_workers),
+            drop_last=False,
+            **self._dataloader_kwargs(),
         )
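A minimal usage sketch, not part of the diff: assuming two already-configured data modules dm_a and dm_b whose batch_size, num_workers, persistent_workers, prefetch_factor, and pin_memory settings match, the concatenated module plugs into a Lightning Trainer like any other LightningDataModule. dm_a, dm_b, and model below are hypothetical stand-ins.

from lightning.pytorch import Trainer

# dm_a and dm_b: configured data modules with matching loader settings;
# model: a LightningModule compatible with the concatenated samples.
combined = BatchedConcatDataModule([dm_a, dm_b])
trainer = Trainer(max_epochs=10)
trainer.fit(model, datamodule=combined)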