Display max_epoch=xxx reached message when EarlyStopping triggers #21031

@GoldenStain

Bug description

When the EarlyStopping callback signals the Trainer to stop, the stop reason logged by `Trainer.fit` is wrong: it reports that `max_epochs` was reached instead of reporting the early stop.

How to reproduce

Any case where EarlyStopping is triggered can reproduce this.

Cause

In `fit_loop.done`, the order of the `if` statements is wrong: the `max_epochs` check runs before the `trainer.should_stop` check, so when both conditions are true the `max_epochs=...` message is printed even though early stopping is what actually ended training.

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        if self.max_batches == 0:
            rank_zero_info("`Trainer.fit` stopped: No training batches.")
            return True

        # TODO: Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        if stop_steps:
            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
            return True

        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        assert isinstance(self.max_epochs, int)
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
        if stop_epochs:
            # in case they are not equal, override so `trainer.current_epoch` has the expected value
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
            return True

        if self.trainer.should_stop and self._can_stop_early:
            rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
            return True

        return False

Solution

Reorder the `if` statements in `fit_loop.done` so that the `trainer.should_stop` check runs before the `max_epochs` check, as sketched below.
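
A minimal sketch of that reordering, keeping each branch's body exactly as in the property shown above (the actual upstream patch may differ):

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        if self.max_batches == 0:
            rank_zero_info("`Trainer.fit` stopped: No training batches.")
            return True

        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        if stop_steps:
            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
            return True

        # check the early-stop request first so it is reported even when the
        # epoch counter happens to reach `max_epochs` in the same iteration
        if self.trainer.should_stop and self._can_stop_early:
            rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
            return True

        assert isinstance(self.max_epochs, int)
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
        if stop_epochs:
            # in case they are not equal, override so `trainer.current_epoch` has the expected value
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
            return True

        return False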

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Code of Module

import torch

# Note: `MLPModule` (assumed to be a LightningModule subclass that also defines
# `configure_optimizers`) and `comprehensive_distributed_check` are helpers from the
# reporter's own project and are not shown here.
class IrisClassifier(MLPModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 3)
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.automatic_optimization = False  # Disable automatic optimization

    def on_train_start(self):
        """Check if running in a distributed environment."""
        comprehensive_distributed_check()
        print("distributed strategy:", self.trainer.strategy)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):

        x = torch.stack(
            [
                batch["sepal_length"],
                batch["sepal_width"],
                batch["petal_length"],
                batch["petal_width"],
            ],
            dim=1,
        ).float()  # shape: [batch, 4]

        species_map = {"setosa": 0, "versicolor": 1, "virginica": 2}
        y = torch.tensor(
            [species_map[s] for s in batch["species"]],
            dtype=torch.long,
            device=self.device,
        )

        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        opt = self.optimizers()
        opt.zero_grad()

        self.manual_backward(loss)

        opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        x = torch.stack(
            [
                batch["sepal_length"],
                batch["sepal_width"],
                batch["petal_length"],
                batch["petal_width"],
            ],
            dim=1,
        ).float()
        species_map = {"setosa": 0, "versicolor": 1, "virginica": 2}
        y = torch.tensor(
            [species_map[s] for s in batch["species"]],
            dtype=torch.long,
            device=self.device,
        )

        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        print(f"on Rank{self.trainer.global_rank} the batch_idx is {batch_idx} the val_acc is {acc}")
        self.log("Validation/val_loss", loss, prog_bar=True)
        self.log("Validation/val_acc", acc, prog_bar=True, sync_dist=True)

Configuration I used

max_epochs=1,
EarlyStopping(monitor="Validation/val_acc", mode="max", min_delta=100, patience=2),
val_check_interval=2,
limit_val_batches=10,
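
For reference, a sketch of how these arguments wire into the Trainer call (`train_loader` and `val_loader` stand in for the Iris dataloaders, which are not included in the report):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

trainer = Trainer(
    max_epochs=1,
    callbacks=[
        # min_delta=100 means the monitored metric can never "improve",
        # so EarlyStopping signals the Trainer to stop after `patience` evaluations
        EarlyStopping(monitor="Validation/val_acc", mode="max", min_delta=100, patience=2),
    ],
    val_check_interval=2,   # run validation every 2 training batches
    limit_val_batches=10,
)
trainer.fit(IrisClassifier(), train_dataloaders=train_loader, val_dataloaders=val_loader)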

Error messages and logs

Monitored metric Validation/val_acc did not improve in the last 2 records. Best score: 0.317. Signaling Trainer to stop.
`Trainer.fit` stopped: `max_epochs=1` reached.

The message is supposed to be: `Trainer.fit` stopped: `trainer.should_stop` was set.

Environment

Current environment
- PyTorch Lightning Version: 2.5.1.post0
- PyTorch Version: 2.7.0
- Python version: 3.9.6
- OS: macOS
- CUDA/cuDNN version: null
- GPU models and configuration: null
- How you installed Lightning (conda, pip, source): uv pip

More info

No response
