Bug description
Context
I'm training in a DDP setting on IterableDataset instances that can return different numbers of batches. To avoid any deadlock caused by one GPU receiving more batches than the others, I use a dataloader that yields a signal once it has exhausted the original iterator. I use this signal to stop the training loop by returning -1 in the `on_train_batch_start` hook:
```python
def on_train_batch_start(self, batch: Batch | ExhaustedDataloaderSignal, batch_idx: int):
    # Check if all dataloaders have data.
    # If not, stop training for the current epoch to avoid deadlock issues between GPUs.
    has_data = True
    if isinstance(batch, ExhaustedDataloaderSignal):
        logger.debug(f"Rank {self.local_rank} has exhausted data on batch {batch_idx}.")
        has_data = False
    if self.trainer.world_size > 1:
        # Convert has_data to a tensor for synchronization
        has_data = torch.tensor(has_data, dtype=torch.bool)
        # Synchronize has_data across ranks
        has_data = self.all_gather(has_data)
        # Check if all ranks have data
        has_data = bool(has_data.all().cpu())
    if not has_data:
        # Stop training for the current epoch
        return -1
    ...
```
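For context, the wrapper that yields this signal looks roughly like the following (a simplified sketch, not the exact implementation; `ExhaustedDataloaderSignal` and the wrapper class name here are illustrative):

```python
from dataclasses import dataclass
from typing import Iterable, Iterator


@dataclass
class ExhaustedDataloaderSignal:
    """Sentinel yielded once the wrapped dataloader has run out of batches."""


class SentinelDataLoader:
    """Wraps an iterable dataloader and keeps yielding a sentinel after it is
    exhausted, so no rank blocks waiting for a batch that never arrives."""

    def __init__(self, dataloader: Iterable) -> None:
        self.dataloader = dataloader

    def __iter__(self) -> Iterator:
        yield from self.dataloader
        # From here on, every iteration produces the sentinel; the
        # on_train_batch_start hook above turns it into a synchronized stop.
        while True:
            yield ExhaustedDataloaderSignal()
```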
Each training epoch finishes as expected, except the learning rate is not updated according to the scheduler that has been specified.
Expected behaviour
The epoch terminates and the learning rate is updated according to the configured scheduler.
Related documentation
The documentation for the `on_train_batch_start` hook reads:

> If you return -1 here, you will skip training for the rest of the current epoch.

It is not clear whether this implies that learning rate scheduler updates occurring at the end of the epoch will also be skipped.
Likely origin of the problem
While checking the code, one can see that if the hook returns -1, a `StopIteration` is raised and the learning rate scheduler updates are never called.
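In other words, the control flow looks roughly like this (an illustrative toy, not Lightning's actual source):

```python
def run_toy_epoch(batches, on_train_batch_start):
    """Toy loop mirroring the behaviour described above, for illustration only."""
    try:
        for batch_idx, batch in enumerate(batches):
            # Returning -1 from the hook translates into a StopIteration inside the loop.
            if on_train_batch_start(batch, batch_idx) == -1:
                raise StopIteration
            print(f"training on batch {batch_idx}")
    except StopIteration:
        return  # early exit path taken when the hook returns -1
    # Epoch-level work, such as stepping an interval="epoch" LR scheduler, only
    # happens when the loop completes normally, so the early exit skips it.
    print("stepping epoch-level LR scheduler")


# The scheduler line is never printed when the hook aborts the epoch:
run_toy_epoch(range(4), lambda batch, idx: -1 if idx == 2 else None)
```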
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Define a LightningModule with a learning rate scheduler that steps with a 1-epoch frequency and an `on_train_batch_start` hook that returns -1 during an epoch.
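For example, something along these lines should show it (a minimal sketch with a toy model; class names and tensor shapes are placeholders, not taken from my actual setup):

```python
import torch
from torch.utils.data import DataLoader, Dataset
import lightning.pytorch as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class ReproModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def on_train_batch_start(self, batch, batch_idx):
        # Abort the epoch early, as in the setting described above.
        if batch_idx == 2:
            return -1

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        self._scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # Scheduler configured to step once per epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": self._scheduler, "interval": "epoch", "frequency": 1},
        }


if __name__ == "__main__":
    model = ReproModule()
    trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
    trainer.fit(model, DataLoader(RandomDataset(), batch_size=2))
    # The StepLR should have decayed the LR once per epoch; with the early
    # return above, the epoch-level steps are skipped (the reported bug).
    print(model._scheduler.get_last_lr())
```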
Error messages and logs
No error message is produced; the learning rate simply doesn't get updated.
Environment
Current environment
- PyTorch Lightning Version: 2.5.5
- PyTorch Version: 2.8.0+cu126
- Python version: 3.13
- OS: Linux
- How you installed Lightning (`conda`, `pip`, source): uv pip
More info
No response