
Commit 692e9df

docs(fit_loop): Improve documentation and comments
- Added comments to clarify the purpose of methods in fit_loop.py.
- Explained the logic behind multiple calls to setup_data().
- Documented handling of validation checks and logging intervals.

These changes aim to enhance code readability and maintainability for future developers.
1 parent 3627c5b commit 692e9df

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 40 additions & 13 deletions
@@ -34,6 +34,22 @@
     _request_dataloader,
     _resolve_overfit_batches,
 )
+ [16 added lines (new lines 37-52) not captured in this page view]
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.combined_loader import _SUPPORTED_MODES, CombinedLoader
@@ -211,28 +227,36 @@ def run(self) -> None:
         self.on_run_end()

     def setup_data(self) -> None:
+        """Sets up the training data loaders.
+
+        This method checks if the data loader needs to be reloaded based on the current epoch and the specified
+        conditions. It initializes the combined loader for training and handles overfitting scenarios.
+        """
         if self._combined_loader is not None and not self._should_reload_train_dl:
-            return
+            return  # No need to reload if already set up and not required

         trainer = self.trainer
         pl_module = trainer.lightning_module
         if trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
-            return
+            return  # Skip setup if no training batches or training step is not overridden

         log.debug(f"{self.__class__.__name__}: resetting train dataloader")

         source = self._data_source
         train_dataloader = _request_dataloader(source)
         trainer.strategy.barrier("train_dataloader()")

+        # Initialize combined loader
         if not isinstance(train_dataloader, CombinedLoader):
             combined_loader = CombinedLoader(train_dataloader, "max_size_cycle")
         else:
             combined_loader = train_dataloader

+        # Handle overfitting scenarios
         if trainer.overfit_batches > 0:
             _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

+        # Process each data loader
         trainer_fn = TrainerFn.FITTING
         stage = RunningStage.TRAINING
         dataloaders = []
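For context, a minimal usage sketch of the reload behaviour the new docstring describes; the model and datamodule names are placeholders, not part of this diff. With reload_dataloaders_every_n_epochs set on the Trainer, _should_reload_train_dl becomes true on the matching epochs, so setup_data() rebuilds the CombinedLoader instead of returning early.

import lightning.pytorch as pl

trainer = pl.Trainer(
    max_epochs=4,
    # Rebuild the train dataloader every 2 epochs; FitLoop.setup_data() then
    # skips its early return and re-wraps the loader in a CombinedLoader.
    reload_dataloaders_every_n_epochs=2,
)
# trainer.fit(MyLightningModule(), datamodule=MyDataModule())  # placeholder user classes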
@@ -243,13 +267,14 @@ def setup_data(self) -> None:
         combined_loader.flattened = dataloaders
         self._combined_loader = combined_loader

+        # Allow zero-length dataloaders if specified
         allow_zero_length = pl_module.allow_zero_length_dataloader_with_multiple_devices
         if trainer.datamodule is not None:
             allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

         limits = []
         for dl in combined_loader.flattened:
-            # determine number of batches
+            # Determine number of batches
             length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
             num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches)
             limits.append(num_batches)
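A short sketch of how the per-dataloader batch count is derived from limit_train_batches; the concrete sizes are illustrative assumptions, not values from this diff.

import lightning.pytorch as pl

# limit_train_batches as a float caps each dataloader at a fraction of its length,
# as an int it caps it at an absolute number of batches.
trainer = pl.Trainer(limit_train_batches=0.25)  # a 200-batch dataloader -> 50 batches per epoch
trainer = pl.Trainer(limit_train_batches=100)   # at most 100 batches per epoch
# A dataloader without a usable __len__ counts as float("inf") batches,
# so only an integer limit yields a finite count.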
@@ -260,17 +285,18 @@ def setup_data(self) -> None:

         self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
         self._data_fetcher.setup(combined_loader)
-        iter(self._data_fetcher)  # creates the iterator inside the fetcher
+        iter(self._data_fetcher)  # Creates the iterator inside the fetcher
         max_batches = sized_len(combined_loader)
         self.max_batches = max_batches if max_batches is not None else float("inf")
         has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)

         if self.max_batches == 0:
-            return
+            return  # No batches to process

-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        # Store epoch of dataloader reset for reload_dataloaders_every_n_epochs
         self._last_train_dl_reload_epoch = trainer.current_epoch

+        # Validation check interval logic
         if isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
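A minimal sketch of the two val_check_interval modes this block handles; the numbers are assumed for illustration.

import lightning.pytorch as pl

# float: validate after that fraction of the training batches in an epoch,
# e.g. with max_batches == 400, val_check_batch == int(400 * 0.25) == 100.
trainer = pl.Trainer(val_check_interval=0.25)

# int: validate every N training batches; with check_val_every_n_epoch=None
# the interval is counted across epochs (iteration-based training).
trainer = pl.Trainer(val_check_interval=200, check_val_every_n_epoch=None)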
@@ -294,6 +320,7 @@ def setup_data(self) -> None:
             trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
             trainer.val_check_batch = max(1, trainer.val_check_batch)

+        # Warning for logging intervals
         if trainer.loggers and self.max_batches < trainer.log_every_n_steps and not trainer.fast_dev_run:
             rank_zero_warn(
                 f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
@@ -312,30 +339,30 @@ def reset(self) -> None:

     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
-        # update the current_epoch in-case of checkpoint reload
+        # Update the current_epoch in case of checkpoint reload
         if not self._iteration_based_training():
             self.epoch_progress.current.completed = self.epoch_progress.current.processed

         trainer = self.trainer

-        # reload the evaluation dataloaders too for proper display in the progress bar
+        # Reload the evaluation dataloaders for proper display in the progress bar
         if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None:
             trainer.validating = True
-            self.epoch_loop.val_loop.setup_data()
+            self.epoch_loop.val_loop.setup_data()  # Setup validation data
             trainer.training = True

         call._call_callback_hooks(trainer, "on_train_start")
         call._call_lightning_module_hook(trainer, "on_train_start")
         call._call_strategy_hook(trainer, "on_train_start")

     def on_advance_start(self) -> None:
-        """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``"""
+        """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``."""
         trainer = self.trainer

-        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
-        self.setup_data()
+        # Might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
+        self.setup_data()  # This ensures data is fresh for the current epoch

-        # update the epoch value for all samplers
+        # Update the epoch value for all samplers
         assert self._combined_loader is not None
         for i, dl in enumerate(self._combined_loader.flattened):
             _set_sampler_epoch(dl, self.epoch_progress.current.processed)
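_set_sampler_epoch forwards the current epoch to samplers that expose set_epoch, such as torch.utils.data.distributed.DistributedSampler. A standalone sketch of why the per-epoch call matters follows; the toy dataset size and the fixed num_replicas/rank are assumptions made only so the example runs without a process group.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
# num_replicas and rank given explicitly so no distributed init is needed here
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
print(list(sampler))  # shuffled index order for epoch 0
sampler.set_epoch(1)
print(list(sampler))  # a different order; without set_epoch every epoch repeats epoch 0's order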
