Open
Description
Bug description
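When the EarlyStopping callback signals the Trainer to stop, the stop message that is displayed is wrong: it reports that `max_epochs` was reached rather than that `trainer.should_stop` was set.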
How to reproduce
Any case where EarlyStopping is triggered can reproduce this.
Cause
In `fit_loop.done`, the order of the `if` statements is wrong.
@property
def done(self) -> bool:
    """Evaluates when to leave the loop."""
    if self.max_batches == 0:
        rank_zero_info("`Trainer.fit` stopped: No training batches.")
        return True

    # TODO: Move track steps inside training loop and move part of these condition inside training loop
    stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
    if stop_steps:
        rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
        return True

    # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
    # we use it here because the checkpoint data won't have `completed` increased yet
    assert isinstance(self.max_epochs, int)
    stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
    if stop_epochs:
        # in case they are not equal, override so `trainer.current_epoch` has the expected value
        self.epoch_progress.current.completed = self.epoch_progress.current.processed
        rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
        return True

    if self.trainer.should_stop and self._can_stop_early:
        rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
        return True

    return False
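To illustrate why the order matters, here is a minimal standalone sketch of the check order shown above. It is not Lightning code: `stop_message_current` and its parameters are hypothetical stand-ins for the values read in `done`, the `max_steps` check is omitted, and `_can_stop_early` is assumed to be true.

```python
from typing import Optional


def stop_message_current(
    processed_epochs: int,
    max_epochs: int,
    should_stop: bool,
) -> Optional[str]:
    """Hypothetical model of the current check order (not Lightning's code)."""
    # The epoch limit is tested first ...
    if processed_epochs >= max_epochs:
        return f"`Trainer.fit` stopped: `max_epochs={max_epochs!r}` reached."
    # ... so this branch is only reached when the epoch limit was not hit.
    if should_stop:
        return "`Trainer.fit` stopped: `trainer.should_stop` was set."
    return None


# Assumed scenario: early stopping fires during the only epoch (max_epochs=1)
# and that epoch already counts as processed when the check runs.
print(stop_message_current(processed_epochs=1, max_epochs=1, should_stop=True))
```

In this model, when both conditions hold at the same time, the `max_epochs` branch returns first and the `should_stop` message is never produced.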
Solution
Reorder the `if` statements so that the `trainer.should_stop` check runs before the `max_epochs` check.
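One possible reordering, sketched as a standalone function rather than as a patch to Lightning. The name `stop_message_reordered` and all of its parameters are hypothetical; the sketch only models which message is chosen and ignores side effects such as the `completed = processed` override in the `max_epochs` branch, which a real fix would still need to handle.

```python
from typing import Optional


def stop_message_reordered(
    global_step: int,
    max_steps: int,
    processed_epochs: int,
    max_epochs: int,
    should_stop: bool,
    can_stop_early: bool = True,
) -> Optional[str]:
    """Hypothetical model of `done` with the early-stop check moved first."""
    # Check the explicit stop request before any limit check.
    if should_stop and can_stop_early:
        return "`Trainer.fit` stopped: `trainer.should_stop` was set."
    # -1 is treated as "no step limit" in this sketch.
    if max_steps != -1 and global_step >= max_steps:
        return f"`Trainer.fit` stopped: `max_steps={max_steps!r}` reached."
    if processed_epochs >= max_epochs:
        return f"`Trainer.fit` stopped: `max_epochs={max_epochs!r}` reached."
    return None


# Same assumed scenario as in the report: max_epochs=1, early stopping fired.
print(
    stop_message_reordered(
        global_step=2,
        max_steps=-1,
        processed_epochs=1,
        max_epochs=1,
        should_stop=True,
    )
)
```

Note that the existing `should_stop` branch logs with `rank_zero_debug`, so depending on the logging level the message may still not be shown even after reordering.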
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Code of Module
class IrisClassifier(MLPModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 3)
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.automatic_optimization = False  # Disable automatic optimization

    def on_train_start(self):
        """Check if running in a distributed environment."""
        comprehensive_distributed_check()
        print("distributed strategy:", self.trainer.strategy)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = torch.stack(
            [
                batch["sepal_length"],
                batch["sepal_width"],
                batch["petal_length"],
                batch["petal_width"],
            ],
            dim=1,
        ).float()  # shape: [batch, 4]
        species_map = {"setosa": 0, "versicolor": 1, "virginica": 2}
        y = torch.tensor(
            [species_map[s] for s in batch["species"]],
            dtype=torch.long,
            device=self.device,
        )
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss

    def validation_step(self, batch, batch_idx):
        x = torch.stack(
            [
                batch["sepal_length"],
                batch["sepal_width"],
                batch["petal_length"],
                batch["petal_width"],
            ],
            dim=1,
        ).float()
        species_map = {"setosa": 0, "versicolor": 1, "virginica": 2}
        y = torch.tensor(
            [species_map[s] for s in batch["species"]],
            dtype=torch.long,
            device=self.device,
        )
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        print(f"on Rank{self.trainer.global_rank} the batch_idx is {batch_idx} the val_acc is {acc}")
        self.log("Validation/val_loss", loss, prog_bar=True)
        self.log("Validation/val_acc", acc, prog_bar=True, sync_dist=True)
Configuration I used
max_epochs=1,
EarlyStopping(monitor="Validation/val_acc", mode="max", min_delta=100, patience=2),
val_check_interval=2,
limit_val_batches=10,
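For context, a sketch of how these settings would typically be passed to a `Trainer`. The report does not show this part of the script, so the helper name `build_trainer` and the `lightning.pytorch` import path are assumptions (the `pytorch_lightning` package exposes the same classes).

```python
def build_trainer():
    """Hypothetical helper wiring the reported settings into a Trainer."""
    # Imported lazily so this snippet can be loaded without Lightning installed.
    # Assumption: the unified `lightning` package is used.
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping

    early_stopping = EarlyStopping(
        monitor="Validation/val_acc",
        mode="max",
        # Accuracy lies in [0, 1], so a min_delta of 100 means no later value
        # can count as an improvement over the first recorded score.
        min_delta=100,
        patience=2,
    )
    return Trainer(
        max_epochs=1,
        callbacks=[early_stopping],
        val_check_interval=2,  # an int means: validate every 2 training batches
        limit_val_batches=10,  # an int means: use at most 10 validation batches
    )
```

With these values, early stopping should be signalled within the single allowed epoch, which appears to be the situation in which the `max_epochs` message is printed instead.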
Error messages and logs
Monitored metric Validation/val_acc did not improve in the last 2 records. Best score: 0.317. Signaling Trainer to stop.
`Trainer.fit` stopped: `max_epochs=1` reached.
The message is supposed to be: `Trainer.fit` stopped: `trainer.should_stop` was set.
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0): 2.5.1.post0
#- PyTorch Version (e.g., 2.5): 2.7.0
#- Python version (e.g., 3.12): 3.9.6
#- OS (e.g., Linux): macOS
#- CUDA/cuDNN version: null
#- GPU models and configuration: null
#- How you installed Lightning(`conda`, `pip`, source): uv pip
More info
No response