    _request_dataloader,
    _resolve_overfit_batches,
)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import _SUPPORTED_MODES, CombinedLoader
@@ -211,28 +227,36 @@ def run(self) -> None:
        self.on_run_end()

    def setup_data(self) -> None:
+        """Sets up the training dataloaders.
+
+        This method checks whether the dataloader needs to be reloaded based on the current epoch and the
+        configured reload conditions. It initializes the combined loader for training and handles overfitting scenarios.
+        """
        if self._combined_loader is not None and not self._should_reload_train_dl:
-            return
+            return  # No need to reload: the loader is already set up and a reload is not required

        trainer = self.trainer
        pl_module = trainer.lightning_module
        if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
-            return
+            return  # Skip setup: there are no training batches or `training_step` is not overridden

        log.debug(f"{self.__class__.__name__}: resetting train dataloader")

        source = self._data_source
        train_dataloader = _request_dataloader(source)
        trainer.strategy.barrier("train_dataloader()")

+        # Initialize the combined loader
        if not isinstance(train_dataloader, CombinedLoader):
            combined_loader = CombinedLoader(train_dataloader, "max_size_cycle")
        else:
            combined_loader = train_dataloader

+        # Handle overfitting scenarios
        if trainer.overfit_batches > 0:
            _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

+        # Process each dataloader
        trainer_fn = TrainerFn.FITTING
        stage = RunningStage.TRAINING
        dataloaders = []
@@ -243,13 +267,14 @@ def setup_data(self) -> None:
        combined_loader.flattened = dataloaders
        self._combined_loader = combined_loader

+        # Allow zero-length dataloaders if specified
        allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
        if trainer.datamodule is not None:
            allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

        limits = []
        for dl in combined_loader.flattened:
-            # determine number of batches
+            # Determine the number of batches
            length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
            num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches)
            limits.append(num_batches)
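Aside: the `limits` collected in the loop above come from `trainer.limit_train_batches`, where an integer acts as an absolute cap on the batch count and a float as a fraction of the dataloader length. A minimal sketch of that mapping, using a hypothetical helper rather than the real `_parse_num_batches`:

# Illustrative sketch only (hypothetical helper, not Lightning's `_parse_num_batches`).
def approx_num_batches(length, limit):
    if isinstance(limit, int):
        return min(length, limit)   # integer limit: absolute cap on the number of batches
    return int(length * limit)      # float limit: fraction of the dataloader length

print(approx_num_batches(100, 10))    # 10 batches
print(approx_num_batches(100, 0.25))  # 25 batches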
@@ -260,17 +285,18 @@ def setup_data(self) -> None:

        self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
        self._data_fetcher.setup(combined_loader)
-        iter(self._data_fetcher)  # creates the iterator inside the fetcher
+        iter(self._data_fetcher)  # Creates the iterator inside the fetcher
        max_batches = sized_len(combined_loader)
        self.max_batches = max_batches if max_batches is not None else float("inf")
        has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)

        if self.max_batches == 0:
-            return
+            return  # No batches to process

-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        # Store epoch of dataloader reset for reload_dataloaders_every_n_epochs
        self._last_train_dl_reload_epoch = trainer.current_epoch

+        # Validation check interval logic
        if isinstance(trainer.val_check_interval, int):
            trainer.val_check_batch = trainer.val_check_interval
            if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
@@ -294,6 +320,7 @@ def setup_data(self) -> None:
                trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
                trainer.val_check_batch = max(1, trainer.val_check_batch)

+        # Warning for logging intervals
        if trainer.loggers and self.max_batches < trainer.log_every_n_steps and not trainer.fast_dev_run:
            rank_zero_warn(
                f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
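Aside: with a float `val_check_interval`, the branch above converts the fraction into a concrete batch count. A short worked example with assumed numbers:

# Assumed numbers for illustration: 500 training batches, validate every quarter of the epoch.
max_batches = 500
val_check_interval = 0.25
val_check_batch = max(1, int(max_batches * val_check_interval))
print(val_check_batch)  # 125 -> run validation every 125 training batches
# An integer `val_check_interval` is used directly instead, and must not exceed `max_batches`
# when `check_val_every_n_epoch` is set.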
@@ -312,30 +339,30 @@ def reset(self) -> None:

    def on_run_start(self) -> None:
        """Calls the ``on_train_start`` hook."""
-        # update the current_epoch in case of checkpoint reload
+        # Update the current_epoch in case of checkpoint reload
        if not self._iteration_based_training():
            self.epoch_progress.current.completed = self.epoch_progress.current.processed

        trainer = self.trainer

-        # reload the evaluation dataloaders too for proper display in the progress bar
+        # Reload the evaluation dataloaders for proper display in the progress bar
        if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None:
            trainer.validating = True
-            self.epoch_loop.val_loop.setup_data()
+            self.epoch_loop.val_loop.setup_data()  # Set up the validation data
            trainer.training = True

        call._call_callback_hooks(trainer, "on_train_start")
        call._call_lightning_module_hook(trainer, "on_train_start")
        call._call_strategy_hook(trainer, "on_train_start")

    def on_advance_start(self) -> None:
-        """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``"""
+        """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``."""
        trainer = self.trainer

-        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
-        self.setup_data()
+        # Might need to set up data again depending on `trainer.reload_dataloaders_every_n_epochs`
+        self.setup_data()  # This ensures the data is fresh for the current epoch

-        # update the epoch value for all samplers
+        # Update the epoch value for all samplers
        assert self._combined_loader is not None
        for i, dl in enumerate(self._combined_loader.flattened):
            _set_sampler_epoch(dl, self.epoch_progress.current.processed)
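Aside: because `on_advance_start` calls `setup_data()` every epoch and `setup_data()` returns early unless a reload is due, the reload cadence is ultimately controlled by the Trainer configuration. A hedged usage sketch against the public Lightning API (model and datamodule are placeholders):

import lightning.pytorch as pl

# Rebuild the dataloaders every 2 epochs so `setup_data()` performs a real reload
# instead of returning early on subsequent epochs.
trainer = pl.Trainer(max_epochs=10, reload_dataloaders_every_n_epochs=2)
# trainer.fit(model, datamodule=datamodule)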