Bug description
When using the built-in ThroughputMonitor callback, training can crash with:
ValueError: Expected the value to increase, last: X, current: X
This happens because, in on_validation_end, Lightning currently computes training_finished as:
training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
However, _time already stores cumulative elapsed times per batch (monotonic, increasing). Summing them again produces a huge, incorrect number, which eventually makes the internal _time.append assertion fail ("Expected the value to increase").
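For illustration, here is a minimal sketch (with made-up elapsed values, not taken from the run above) of how summing the already-cumulative entries overshoots the actual finish time:

# Hypothetical cumulative elapsed times (seconds since the training t0) for four batches.
cumulative_times = [1.0, 2.0, 3.0, 4.0]
t0 = 1000.0  # hypothetical monotonic start timestamp

buggy_finished = t0 + sum(cumulative_times)   # 1010.0 -- double-counts the earlier batches
fixed_finished = t0 + cumulative_times[-1]    # 1004.0 -- the last entry already holds the total elapsed time
print(buggy_finished, fixed_finished)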
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
import os

from torch import optim, nn, utils
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import lightning as L
from lightning.pytorch.callbacks import ThroughputMonitor
from lightning.pytorch.trainer.trainer import Trainer

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


class CustomThroughputMonitor(ThroughputMonitor):
    def __init__(self, batch_size_fn=None, *args, **kwargs) -> None:
        super().__init__(batch_size_fn=batch_size_fn, *args, **kwargs)

    def setup(self, trainer: Trainer, pl_module, stage):
        self.batch_size_fn = lambda batch: 1
        return super().setup(trainer, pl_module, stage)


class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


autoencoder = LitAutoEncoder(encoder, decoder)

dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_set, val_set = random_split(dataset, [55000, 5000])
train_loader = utils.data.DataLoader(train_set, batch_size=64)
val_loader = utils.data.DataLoader(val_set, batch_size=64)

trainer = L.Trainer(
    limit_train_batches=500,
    max_epochs=20,
    callbacks=[CustomThroughputMonitor()],
    log_every_n_steps=100,
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
Error messages and logs
raise ValueError(f"Expected the value to increase, last: {last}, current: {x}")
ValueError: Expected the value to increase, last: 358569066962426.94, current: 358569066962426.94
Environment
No response
More info
Proposed fix:
Replace the sum() with just the last cumulative elapsed value, which already represents the total time elapsed since the training t0:
@rank_zero_only
def on_validation_end(self, trainer, *_):
    if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
        return
    training_finished = self._t0s[RunningStage.TRAINING] + self._throughputs[RunningStage.TRAINING]._time[-1]
    time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
    val_time = self._throughputs[RunningStage.VALIDATING]._time[-1]
    self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
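Until a fix lands upstream, a possible workaround (a sketch, assuming the private _t0s/_throughputs attributes referenced above stay as they are in v2.5) is to subclass the callback and override on_validation_end with the corrected computation:

from lightning.pytorch.callbacks import ThroughputMonitor
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities import rank_zero_only


class PatchedThroughputMonitor(ThroughputMonitor):
    # Hypothetical workaround subclass; mirrors the proposed fix above.
    @rank_zero_only
    def on_validation_end(self, trainer, *_):
        if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
            return
        # Use the last cumulative elapsed value instead of summing the whole series.
        training_finished = self._t0s[RunningStage.TRAINING] + self._throughputs[RunningStage.TRAINING]._time[-1]
        time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
        val_time = self._throughputs[RunningStage.VALIDATING]._time[-1]
        self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time


# Usage in the repro script above:
# trainer = L.Trainer(callbacks=[PatchedThroughputMonitor()], ...)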