docs(fit_loop): Improve documentation and comments #20426
@@ -211,28 +211,37 @@ def run(self) -> None:
         self.on_run_end()

     def setup_data(self) -> None:
+        """Sets up the training data loaders.
+
+        This method checks if the data loader needs to be reloaded based on the current epoch and the specified
+        conditions. It initializes the combined loader for training and handles overfitting scenarios.
+
+        """
         if self._combined_loader is not None and not self._should_reload_train_dl:
-            return
+            return  # No need to reload if already set up and not required

         trainer = self.trainer
         pl_module = trainer.lightning_module
         if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
-            return
+            return  # Skip setup if no training batches or training step is not overridden

         log.debug(f"{self.__class__.__name__}: resetting train dataloader")

         source = self._data_source
         train_dataloader = _request_dataloader(source)
         trainer.strategy.barrier("train_dataloader()")

+        # Initialize combined loader
         if not isinstance(train_dataloader, CombinedLoader):
             combined_loader = CombinedLoader(train_dataloader, "max_size_cycle")
         else:
             combined_loader = train_dataloader

+        # Handle overfitting scenarios
Review comment: This is not correct. Overfit batches is a feature of Trainer to purposefully overfit, in order to spot potential bugs in user code.
         if trainer.overfit_batches > 0:
             _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

+        # Process each data loader
Review comment: For the comment to be useful, we need to expand on what "Process" means here.
         trainer_fn = TrainerFn.FITTING
         stage = RunningStage.TRAINING
         dataloaders = []
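To ground the review note above: `overfit_batches` is a Trainer debugging knob that reuses a small, fixed subset of the training data, so a model that cannot overfit it likely has a bug. A minimal sketch of typical usage, assuming a `LightningModule` named `model` that is not part of this diff:

```python
# Sketch only: `overfit_batches` accepts an int (number of batches) or a float (fraction
# of the training set) and reuses that subset every epoch for debugging purposes.
from lightning.pytorch import Trainer

debug_trainer = Trainer(overfit_batches=10, max_epochs=50)      # reuse the same 10 batches
# debug_trainer = Trainer(overfit_batches=0.01, max_epochs=50)  # or 1% of the training data

# debug_trainer.fit(model)  # `model` is an assumed LightningModule, omitted here
```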
@@ -243,13 +252,14 @@ def setup_data(self) -> None:
         combined_loader.flattened = dataloaders
         self._combined_loader = combined_loader

+        # Allow zero-length dataloaders if specified
         allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
         if trainer.datamodule is not None:
             allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

         limits = []
         for dl in combined_loader.flattened:
-            # determine number of batches
+            # Determine number of batches
             length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
             num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches)
             limits.append(num_batches)
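For readers of this hunk, `_parse_num_batches` caps how many batches each flattened dataloader contributes, based on `trainer.limit_train_batches`. The helper below is a simplified, hypothetical illustration of that rule, not the actual implementation:

```python
from typing import Union

def parse_num_batches_sketch(length: Union[int, float], limit: Union[int, float]) -> Union[int, float]:
    """Hypothetical sketch: an int limit caps the batch count, a float limit scales it."""
    if isinstance(limit, int):
        return min(length, limit)      # at most `limit` batches
    if length == float("inf"):
        return length                  # no finite length to scale (iterable-style loader)
    return int(length * limit)         # fraction of the available batches

print(parse_num_batches_sketch(1000, 100))   # -> 100
print(parse_num_batches_sketch(1000, 0.25))  # -> 250
```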
@@ -260,17 +270,18 @@ def setup_data(self) -> None:

         self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
         self._data_fetcher.setup(combined_loader)
-        iter(self._data_fetcher)  # creates the iterator inside the fetcher
+        iter(self._data_fetcher)  # Creates the iterator inside the fetcher
         max_batches = sized_len(combined_loader)
         self.max_batches = max_batches if max_batches is not None else float("inf")
         has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)

         if self.max_batches == 0:
-            return
+            return  # No batches to process

-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        # Store epoch of dataloader reset for reload_dataloaders_every_n_epochs
         self._last_train_dl_reload_epoch = trainer.current_epoch

+        # Validation check interval logic
Review comment: A comment is only useful if it sheds more light on the underlying mechanisms. Explaining exactly what this logic does is what will make the comment useful for readers; otherwise it is line noise.
         if isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
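To address the reviewer's question about what this logic does: it translates `val_check_interval` into `trainer.val_check_batch`, the number of training batches between validation runs. A rough sketch of that mapping, inferred from the surrounding code rather than copied from it:

```python
from typing import Union

def val_check_batch_sketch(val_check_interval: Union[int, float], max_batches: int) -> int:
    """Rough sketch: int -> validate every N batches; float -> fraction of the epoch."""
    if isinstance(val_check_interval, int):
        return val_check_interval
    return max(1, int(max_batches * val_check_interval))

print(val_check_batch_sketch(50, 800))    # validate every 50 training batches
print(val_check_batch_sketch(0.25, 800))  # validate every 200 training batches
```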
@@ -294,6 +305,7 @@ def setup_data(self) -> None:
                 trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
                 trainer.val_check_batch = max(1, trainer.val_check_batch)

+        # Warning for logging intervals
Review comment: A warning about what, and for what reason?
         if trainer.loggers and self.max_batches < trainer.log_every_n_steps and not trainer.fast_dev_run:
             rank_zero_warn(
                 f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
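Answering the reviewer's question in context: the warning above fires when an epoch has fewer training batches than `log_every_n_steps`, so training metrics would be logged less than once per epoch. A hedged sketch of a configuration that would trigger it during fitting and one that avoids it:

```python
from lightning.pytorch import Trainer

# 10 training batches per epoch but a 50-step logging interval: several epochs can pass
# between logged values, which is what the warning points out once fitting starts.
sparse_logging = Trainer(limit_train_batches=10, log_every_n_steps=50)

# Lowering the interval keeps at least one logged step per epoch and avoids the warning.
dense_logging = Trainer(limit_train_batches=10, log_every_n_steps=10)
```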
@@ -312,30 +324,30 @@ def reset(self) -> None:

     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
-        # update the current_epoch in-case of checkpoint reload
+        # Update the current_epoch in case of checkpoint reload
         if not self._iteration_based_training():
             self.epoch_progress.current.completed = self.epoch_progress.current.processed

         trainer = self.trainer

-        # reload the evaluation dataloaders too for proper display in the progress bar
+        # Reload the evaluation dataloaders for proper display in the progress bar
         if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None:
             trainer.validating = True
-            self.epoch_loop.val_loop.setup_data()
+            self.epoch_loop.val_loop.setup_data()  # Setup validation data
             trainer.training = True

         call._call_callback_hooks(trainer, "on_train_start")
         call._call_lightning_module_hook(trainer, "on_train_start")
         call._call_strategy_hook(trainer, "on_train_start")

     def on_advance_start(self) -> None:
-        """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``"""
+        """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``."""
         trainer = self.trainer

-        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
-        self.setup_data()
+        # Might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
+        self.setup_data()  # This ensures data is fresh for the current epoch

-        # update the epoch value for all samplers
+        # Update the epoch value for all samplers
         assert self._combined_loader is not None
         for i, dl in enumerate(self._combined_loader.flattened):
             _set_sampler_epoch(dl, self.epoch_progress.current.processed)
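The `_set_sampler_epoch` call at the end of this hunk mirrors what PyTorch's `DistributedSampler` expects at every epoch boundary. A standalone sketch of that mechanism; the toy dataset and the explicit `num_replicas`/`rank` arguments are only there so the snippet runs without an initialized process group:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16).float())
# Explicit num_replicas/rank avoid the need for torch.distributed initialization here.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=4)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseed the shuffle so each epoch draws a new permutation
    for (batch,) in loader:
        pass  # training step would go here
```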
Review comment: This needs to be extended: from now on we assume we deal with a `combined_loader`; in the case of a single dataloader, we treat it as a `combined_loader` holding a single data loader.
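Following the review comment above, that invariant can be illustrated with a small sketch: after setup the loop only deals with a `CombinedLoader`, and a plain `DataLoader` is wrapped into one holding a single entry, using the same "max_size_cycle" mode as the diff. The toy dataset below is a placeholder:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch.utilities import CombinedLoader

train_dataloader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2)

# Same wrapping rule as in the diff: anything that is not already a CombinedLoader
# becomes one with a single flattened entry.
if not isinstance(train_dataloader, CombinedLoader):
    combined_loader = CombinedLoader(train_dataloader, "max_size_cycle")
else:
    combined_loader = train_dataloader

print(len(combined_loader.flattened))  # -> 1, a single underlying dataloader
```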