38 changes: 25 additions & 13 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -211,28 +211,37 @@ def run(self) -> None:
self.on_run_end()

def setup_data(self) -> None:
"""Sets up the training data loaders.

This method checks if the data loader needs to be reloaded based on the current epoch and the specified
conditions. It initializes the combined loader for training and handles overfitting scenarios.

"""
if self._combined_loader is not None and not self._should_reload_train_dl:
return
return # No need to reload if already set up and not required

trainer = self.trainer
pl_module = trainer.lightning_module
if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
return
return # Skip setup if no training batches or training step is not overridden

log.debug(f"{self.__class__.__name__}: resetting train dataloader")

source = self._data_source
train_dataloader = _request_dataloader(source)
trainer.strategy.barrier("train_dataloader()")

# Initialize combined loader
Collaborator:

This needs to be extended: from now on we assume we deal with a combined_loader; in the case of a single dataloader, we treat it as a combined_loader holding a single dataloader.

if not isinstance(train_dataloader, CombinedLoader):
combined_loader = CombinedLoader(train_dataloader, "max_size_cycle")
else:
combined_loader = train_dataloader
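
For readers following the reviewer's point: a bare DataLoader gets wrapped here so that downstream code always sees a CombinedLoader. Below is a minimal standalone sketch based on the documented CombinedLoader usage; the exact (batch, batch_idx, dataloader_idx) tuple yielded per step is an assumption about recent Lightning 2.x behavior, not part of this PR:

from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader

iterables = {"a": DataLoader(range(6), batch_size=4),
             "b": DataLoader(range(15), batch_size=5)}
combined = CombinedLoader(iterables, "max_size_cycle")
for batch, batch_idx, dataloader_idx in combined:
    # "max_size_cycle" cycles the shorter loader until the longest one is exhausted,
    # so batch["a"] and batch["b"] are both present on every step
    print(batch, batch_idx, dataloader_idx)

A single DataLoader passed to the same wrapper iterates identically, which is what the isinstance check above guarantees.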

# Handle overfitting scenarios
Collaborator:

This is not correct. Overfit batches is a Trainer feature for purposefully overfitting in order to spot potential bugs in the user's code.

if trainer.overfit_batches > 0:
_resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)
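
To illustrate the feature the reviewer describes (a deliberate debugging aid, not an error-handling path), here is a minimal hedged usage sketch; model stands in for a LightningModule defined elsewhere:

from lightning.pytorch import Trainer

# Repeatedly fit on the same small, fixed subset of the training data.
# If the loss cannot be driven close to zero on just 10 batches,
# the model or optimization code likely has a bug.
trainer = Trainer(overfit_batches=10)  # an int counts batches; a float (e.g. 0.01) is a fraction
trainer.fit(model)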

# Process each data loader
Collaborator:

For the comment to be useful we need to expand what "Process" means

trainer_fn = TrainerFn.FITTING
stage = RunningStage.TRAINING
dataloaders = []
@@ -243,13 +252,14 @@ def setup_data(self) -> None:
combined_loader.flattened = dataloaders
self._combined_loader = combined_loader

# Allow zero-length dataloaders if specified
allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

limits = []
for dl in combined_loader.flattened:
# determine number of batches
# Determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches)
limits.append(num_batches)
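
For context, _parse_num_batches above converts the user-facing limit_train_batches flag into a per-dataloader batch cap. A hedged sketch of that flag in use:

from lightning.pytorch import Trainer

# Use only 10% of each training dataloader per epoch (float = fraction) ...
trainer = Trainer(limit_train_batches=0.1)

# ... or cap every training dataloader at 100 batches per epoch (int = count)
trainer = Trainer(limit_train_batches=100)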
@@ -260,17 +270,18 @@ def setup_data(self) -> None:

self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
self._data_fetcher.setup(combined_loader)
iter(self._data_fetcher) # creates the iterator inside the fetcher
iter(self._data_fetcher) # Creates the iterator inside the fetcher
max_batches = sized_len(combined_loader)
self.max_batches = max_batches if max_batches is not None else float("inf")
has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)

if self.max_batches == 0:
return
return # No batches to process

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
# Store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = trainer.current_epoch

# Validation check interval logic
Collaborator:

A comment is only useful if it sheds more light on the underlying mechanisms.
Adding "Validation check interval logic" on top of if isinstance(trainer.val_check_interval, int) ... is pretty redundant.

What this logic does exactly is what will make the comment useful for readers. Otherwise it's line noise.

if isinstance(trainer.val_check_interval, int):
trainer.val_check_batch = trainer.val_check_interval
if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
@@ -294,6 +305,7 @@ def setup_data(self) -> None:
trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
trainer.val_check_batch = max(1, trainer.val_check_batch)
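
In user-facing terms, the branch above interprets an integer val_check_interval as "validate every N training batches" and a float as a fraction of the training epoch. A hedged illustration:

from lightning.pytorch import Trainer

# Integer: run the validation loop every 500 training batches
trainer = Trainer(val_check_interval=500)

# Float: run validation four times per epoch (every 25% of the training batches)
trainer = Trainer(val_check_interval=0.25)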

# Warning for logging intervals
Collaborator:

A warning about what, and for what reason?

if trainer.loggers and self.max_batches < trainer.log_every_n_steps and not trainer.fast_dev_run:
rank_zero_warn(
f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
@@ -312,30 +324,30 @@ def reset(self) -> None:

def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""
# update the current_epoch in-case of checkpoint reload
# Update the current_epoch in case of checkpoint reload
if not self._iteration_based_training():
self.epoch_progress.current.completed = self.epoch_progress.current.processed

trainer = self.trainer

# reload the evaluation dataloaders too for proper display in the progress bar
# Reload the evaluation dataloaders for proper display in the progress bar
if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None:
trainer.validating = True
self.epoch_loop.val_loop.setup_data()
self.epoch_loop.val_loop.setup_data() # Setup validation data
trainer.training = True

call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")

def on_advance_start(self) -> None:
"""Prepares the dataloader for training and calls the hook ``on_train_epoch_start``"""
"""Prepares the dataloader for training and calls the hook ``on_train_epoch_start``."""
trainer = self.trainer

# might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
self.setup_data()
# Might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
self.setup_data() # This ensures data is fresh for the current epoch
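
The re-setup above is driven by a user-facing Trainer flag; a brief hedged sketch of how it is typically enabled:

from lightning.pytorch import Trainer

# Rebuild the training dataloader at the start of every 10th epoch,
# e.g. when the underlying dataset is resampled or grows over time
trainer = Trainer(reload_dataloaders_every_n_epochs=10)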

# update the epoch value for all samplers
# Update the epoch value for all samplers
assert self._combined_loader is not None
for i, dl in enumerate(self._combined_loader.flattened):
_set_sampler_epoch(dl, self.epoch_progress.current.processed)
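
_set_sampler_epoch forwards the current epoch to samplers that expose a set_epoch method, most notably DistributedSampler, so each epoch shuffles differently across ranks. A standalone sketch of the underlying PyTorch mechanism (not the Lightning helper itself):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=10)

for epoch in range(3):
    # Without set_epoch, DistributedSampler reuses the same shuffled order every epoch
    sampler.set_epoch(epoch)
    for batch in loader:
        ...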