diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index eb30e32757c9a..6ccf752026914 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -211,13 +211,19 @@ def run(self) -> None: self.on_run_end() def setup_data(self) -> None: + """Sets up the training data loaders. + + This method checks if the data loader needs to be reloaded based on the current epoch and the specified + conditions. It initializes the combined loader for training and handles overfitting scenarios. + + """ if self._combined_loader is not None and not self._should_reload_train_dl: - return + return # No need to reload if already set up and not required trainer = self.trainer pl_module = trainer.lightning_module if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module): - return + return # Skip setup if no training batches or training step is not overridden log.debug(f"{self.__class__.__name__}: resetting train dataloader") @@ -225,14 +231,17 @@ def setup_data(self) -> None: train_dataloader = _request_dataloader(source) trainer.strategy.barrier("train_dataloader()") + # Initialize combined loader if not isinstance(train_dataloader, CombinedLoader): combined_loader = CombinedLoader(train_dataloader, "max_size_cycle") else: combined_loader = train_dataloader + # Handle overfitting scenarios if trainer.overfit_batches > 0: _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING) + # Process each data loader trainer_fn = TrainerFn.FITTING stage = RunningStage.TRAINING dataloaders = [] @@ -243,13 +252,14 @@ def setup_data(self) -> None: combined_loader.flattened = dataloaders self._combined_loader = combined_loader + # Allow zero-length dataloaders if specified allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices if trainer.datamodule is not None: allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices limits = [] for dl in combined_loader.flattened: - # determine number of batches + # Determine number of batches length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf") num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches) limits.append(num_batches) @@ -260,17 +270,18 @@ def setup_data(self) -> None: self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING) self._data_fetcher.setup(combined_loader) - iter(self._data_fetcher) # creates the iterator inside the fetcher + iter(self._data_fetcher) # Creates the iterator inside the fetcher max_batches = sized_len(combined_loader) self.max_batches = max_batches if max_batches is not None else float("inf") has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length) if self.max_batches == 0: - return + return # No batches to process - # store epoch of dataloader reset for reload_dataloaders_every_n_epochs + # Store epoch of dataloader reset for reload_dataloaders_every_n_epochs self._last_train_dl_reload_epoch = trainer.current_epoch + # Validation check interval logic if isinstance(trainer.val_check_interval, int): trainer.val_check_batch = trainer.val_check_interval if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None: @@ -294,6 +305,7 @@ def setup_data(self) -> None: trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval) trainer.val_check_batch = max(1, trainer.val_check_batch) + # Warning for logging intervals if trainer.loggers and self.max_batches < trainer.log_every_n_steps and not trainer.fast_dev_run: rank_zero_warn( f"The number of training batches ({self.max_batches}) is smaller than the logging interval" @@ -312,16 +324,16 @@ def reset(self) -> None: def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" - # update the current_epoch in-case of checkpoint reload + # Update the current_epoch in case of checkpoint reload if not self._iteration_based_training(): self.epoch_progress.current.completed = self.epoch_progress.current.processed trainer = self.trainer - # reload the evaluation dataloaders too for proper display in the progress bar + # Reload the evaluation dataloaders for proper display in the progress bar if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None: trainer.validating = True - self.epoch_loop.val_loop.setup_data() + self.epoch_loop.val_loop.setup_data() # Setup validation data trainer.training = True call._call_callback_hooks(trainer, "on_train_start") @@ -329,13 +341,13 @@ def on_run_start(self) -> None: call._call_strategy_hook(trainer, "on_train_start") def on_advance_start(self) -> None: - """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``""" + """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``.""" trainer = self.trainer - # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs` - self.setup_data() + # Might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs` + self.setup_data() # This ensures data is fresh for the current epoch - # update the epoch value for all samplers + # Update the epoch value for all samplers assert self._combined_loader is not None for i, dl in enumerate(self._combined_loader.flattened): _set_sampler_epoch(dl, self.epoch_progress.current.processed)