@@ -181,46 +181,48 @@ def print_pass(*args, **kwargs):
 
     eval_set = "Val"
 
-    # Dataloader --------------------------------------------------------------
-    dl_train_kwargs = {
-        "batch_size": config.batch_size_per_gpu,
-        "drop_last": True,
-        "sampler": None,
-        "shuffle": True,
-        "worker_init_fn": utils.worker_seed_fn,
-    }
-    dl_val_kwargs = {
-        "batch_size": config.batch_size_per_gpu,
-        "drop_last": False,
-        "sampler": None,
-        "shuffle": False,
-        "worker_init_fn": utils.worker_seed_fn,
-    }
-    if config.cpu_workers is None:
-        config.cpu_workers = utils.get_num_cpu_available()
-    if use_cuda:
-        cuda_kwargs = {"num_workers": config.cpu_workers, "pin_memory": True}
-        dl_train_kwargs.update(cuda_kwargs)
-        dl_val_kwargs.update(cuda_kwargs)
-
-    if config.distributed:
-        # The DistributedSampler breaks up the dataset across the GPUs
-        dl_train_kwargs["sampler"] = DistributedSampler(
-            dataset_train,
-            shuffle=True,
-            seed=config.seed if config.seed is not None else 0,
-            drop_last=False,
-        )
-        dl_train_kwargs["shuffle"] = None
-        dl_val_kwargs["sampler"] = DistributedSampler(
-            dataset_val,
-            shuffle=False,
-            drop_last=False,
-        )
-        dl_val_kwargs["shuffle"] = None
-
-    dataloader_train = torch.utils.data.DataLoader(dataset_train, **dl_train_kwargs)
-    dataloader_val = torch.utils.data.DataLoader(dataset_val, **dl_val_kwargs)
+    # Dataloaders -------------------------------------------------------------
+    if config.lazy_load:
+        # streaming IterableDataset → no sampler, no shuffle
+        stream_kwargs = {
+            "batch_size": config.batch_size_per_gpu,
+            "drop_last": True,
+            "num_workers": config.cpu_workers,
+            "pin_memory": use_cuda,
+            "worker_init_fn": utils.worker_seed_fn,
+        }
+        dataloader_train = torch.utils.data.DataLoader(dataset_train, **stream_kwargs)
+        dataloader_val = torch.utils.data.DataLoader(dataset_val, **stream_kwargs)
+    else:
+        # map-style Dataset → use DistributedSampler in distributed mode
+        map_train_kwargs = {
+            "batch_size": config.batch_size_per_gpu,
+            "drop_last": True,
+            "shuffle": True,
+            "num_workers": config.cpu_workers,
+            "pin_memory": use_cuda,
+            "worker_init_fn": utils.worker_seed_fn,
+        }
+        map_val_kwargs = {
+            "batch_size": config.batch_size_per_gpu,
+            "drop_last": False,
+            "shuffle": False,
+            "num_workers": config.cpu_workers,
+            "pin_memory": use_cuda,
+            "worker_init_fn": utils.worker_seed_fn,
+        }
+        if config.distributed:
+            map_train_kwargs["shuffle"] = False
+            map_train_kwargs["sampler"] = DistributedSampler(
+                dataset_train, shuffle=True,
+                seed=(config.seed or 0),
+                drop_last=False,
+            )
+            map_val_kwargs["sampler"] = DistributedSampler(
+                dataset_val, shuffle=False, drop_last=False
+            )
+        dataloader_train = torch.utils.data.DataLoader(dataset_train, **map_train_kwargs)
+        dataloader_val = torch.utils.data.DataLoader(dataset_val, **map_val_kwargs)
 
     # MODEL ===================================================================
     base_pairs = "ACGT"
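
A caveat on the `lazy_load` branch above: with an `IterableDataset` and `num_workers > 0`, every DataLoader worker process receives its own copy of the iterator, so each record is yielded `num_workers` times unless the dataset shards itself per worker. A minimal sketch of that pattern, assuming the dataset can enumerate its records (the `ShardedStream` class and `records` source are illustrative, not part of this repo):

```python
from torch.utils.data import IterableDataset, get_worker_info

class ShardedStream(IterableDataset):
    """Illustrative streaming dataset yielding only records owned by this worker."""

    def __init__(self, records):
        self.records = records  # any lazily readable source of records

    def __iter__(self):
        info = get_worker_info()  # None in the main process (num_workers=0)
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        for i, record in enumerate(self.records):
            if i % num_workers == worker_id:  # round-robin shard per worker
                yield record
```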
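On the map-style branch, note that `DistributedSampler` reshuffles only when it is told the current epoch; without a `set_epoch` call, every rank replays the same permutation each epoch. A minimal sketch of the driving loop, assuming a hypothetical `config.epochs` and hypothetical `train_one_epoch`/`evaluate` helpers:

```python
for epoch in range(config.epochs):  # config.epochs is assumed, not in this diff
    if config.distributed and not config.lazy_load:
        # Seeds the sampler's RNG with (seed + epoch) so each epoch
        # draws a fresh shuffle across ranks.
        dataloader_train.sampler.set_epoch(epoch)
    train_one_epoch(model, dataloader_train)  # hypothetical helper
    evaluate(model, dataloader_val)           # hypothetical helper
```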