LR is not updated when on_train_batch_start returns -1 #21296

Bug description

Context

I'm training in a DDP setting on IterableDataset instances that can yield different numbers of batches. To avoid deadlocks caused by one GPU receiving more batches than the others, I use a dataloader that yields a signal once it has exhausted the original iterator. I use this signal to stop the training loop by returning -1 from the on_train_batch_start hook.

```python
def on_train_batch_start(self, batch: Batch | ExhaustedDataloaderSignal, batch_idx: int):
    # Check if all dataloaders still have data.
    # If not, stop training for the current epoch to avoid deadlocks between GPUs.
    has_data = True
    if isinstance(batch, ExhaustedDataloaderSignal):
        logger.debug(f"Rank {self.local_rank} has exhausted data on batch {batch_idx}.")
        has_data = False
    if self.trainer.world_size > 1:
        # Convert has_data to a tensor for synchronization
        has_data = torch.tensor(has_data, dtype=torch.bool)
        # Synchronize has_data across ranks
        has_data = self.all_gather(has_data)
        # Training continues only if all ranks still have data
        has_data = bool(has_data.all().cpu())
    if not has_data:
        # Stop training for the current epoch
        return -1
    ...
```
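
The exhaustion signal itself comes from a thin wrapper around the dataloader. A minimal sketch of one way such a wrapper could look (SignalingDataLoader and total_steps are illustrative names, not my exact implementation):

```python
class ExhaustedDataloaderSignal:
    """Sentinel yielded in place of a batch once the wrapped loader runs dry."""


class SignalingDataLoader:
    def __init__(self, loader, total_steps: int):
        self.loader = loader
        self.total_steps = total_steps  # upper bound on steps per epoch

    def __iter__(self):
        it = iter(self.loader)
        for _ in range(self.total_steps):
            try:
                yield next(it)
            except StopIteration:
                # Keep yielding so every rank performs the same number of
                # iterations; the hook above turns this into a clean stop.
                yield ExhaustedDataloaderSignal()
```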

Each training epoch finishes as expected, but the learning rate is not updated according to the scheduler that was configured.
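
For reference, the scheduler is configured to step once per epoch; a minimal sketch of that kind of configuration (the optimizer and scheduler choices below are placeholders, not my exact setup):

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",  # step the scheduler once per epoch
            "frequency": 1,
        },
    }
```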

Expected behaviour

The epoch terminates and the learning rate is updated.

Related documentation

The documentation for the on_train_batch_start hook reads:

> If you return -1 here, you will skip training for the rest of the current epoch.

It is not clear whether this implies that learning rate scheduler updates occurring at the end of the epoch will also be skipped.

Likely origin of the problem

Looking at the code, one can see that when the hook returns -1, a StopIteration is raised inside the training epoch loop, so the learning rate scheduler updates are never reached.
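
In pseudocode, the control flow looks roughly like this (a paraphrase for illustration, not the actual Lightning source; all names are made up):

```python
# Paraphrased control flow of the training epoch loop; names are
# illustrative and do not match Lightning's internal API.
class ToyEpochLoop:
    def __init__(self, module, batches):
        self.module = module
        self.batches = batches

    def advance(self, batch, batch_idx):
        if self.module.on_train_batch_start(batch, batch_idx) == -1:
            # Unwinds the loop immediately: everything below, including
            # the epoch-level scheduler update, is never reached.
            raise StopIteration
        self.module.training_step(batch, batch_idx)
        if batch_idx == len(self.batches) - 1:
            # Epoch-interval LR schedulers are only stepped at the end
            # of a fully completed epoch.
            self.module.lr_scheduler.step()

    def run(self):
        try:
            for batch_idx, batch in enumerate(self.batches):
                self.advance(batch, batch_idx)
        except StopIteration:
            pass  # the epoch ends, but the scheduler was never stepped
```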

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Define a LightningModule with a learning rate scheduler that steps with interval "epoch" and frequency 1, and an on_train_batch_start hook that returns -1 partway through an epoch, as sketched below.
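
A minimal sketch of such a module (the model, data, and cutoff batch index are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as L


class Repro(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def on_train_batch_start(self, batch, batch_idx):
        if batch_idx >= 2:  # simulate an exhausted dataloader mid-epoch
            return -1

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "frequency": 1},
        }


if __name__ == "__main__":
    model = Repro()
    loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
    trainer = L.Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
    trainer.fit(model, loader)
    # Expected: the LR decays once per epoch (0.1 -> 0.01 -> 0.001).
    # Observed: it stays at 0.1 because the scheduler is never stepped.
    print(trainer.optimizers[0].param_groups[0]["lr"])
```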

Error messages and logs

No error message is produced; the learning rate simply is not updated.

Environment

- PyTorch Lightning Version: 2.5.5
- PyTorch Version: 2.8.0+cu126
- Python version: 3.13
- OS: Linux
- How you installed Lightning (`conda`, `pip`, source): uv pip

More info

No response
