
Bug in ThroughputMonitor.on_validation_end: using sum() instead of the last value can corrupt the timing state #21257

@itzhakstern


Bug description

When using the built-in ThroughputMonitor callback, training can crash with:
ValueError: Expected the value to increase, last: X, current: X

This happens because in on_validation_end Lightning currently computes training_finished as:
training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)

But _time already stores cumulative elapsed times per batch (monotonic and increasing).
Summing these cumulative values double-counts earlier batches and produces a huge, incorrect offset, which eventually makes the internal check in _time.append fail with "Expected the value to increase".
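To make the arithmetic concrete, here is a minimal sketch with made-up numbers (train_time, t0_training, and their values are illustrative placeholders standing in for the callback's internals, not Lightning's actual data):

# Illustrative sketch only; the list below stands in for the callback's
# cumulative _time window, with invented values.
train_time = [1.0, 2.0, 3.0]   # cumulative seconds since training started, one entry per batch
t0_training = 100.0            # hypothetical monotonic-clock start of training

# Current behaviour: summing a cumulative series double-counts earlier batches.
training_finished_buggy = t0_training + sum(train_time)    # 100 + 6 = 106.0

# What is actually wanted: the last entry already equals the total elapsed time.
training_finished_fixed = t0_training + train_time[-1]     # 100 + 3 = 103.0

print(training_finished_buggy, training_finished_fixed)

With real monotonic-clock timestamps and hundreds of batches per epoch, the inflated offset grows far beyond any plausible wall-clock value, which is consistent with the ~3.6e14 figures in the error log below.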

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import os
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from torch.utils.data import random_split
from lightning.pytorch.callbacks import ThroughputMonitor
from lightning.pytorch.trainer.trainer import Trainer

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# Minimal subclass that hard-codes batch_size_fn so ThroughputMonitor can run on this toy example
class CustomThroughputMonitor(ThroughputMonitor):
    def __init__(self, batch_size_fn=None, *args, **kwargs) -> None:
        super().__init__(batch_size_fn=batch_size_fn, *args, **kwargs)

    def setup(self, trainer: Trainer, pl_module, stage):
        self.batch_size_fn = lambda batch: 1
        return super().setup(trainer, pl_module, stage)

class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

autoencoder = LitAutoEncoder(encoder, decoder)
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_set, val_set = random_split(dataset, [55000, 5000])
train_loader = utils.data.DataLoader(train_set, batch_size=64)
val_loader = utils.data.DataLoader(val_set, batch_size=64)

trainer = L.Trainer(
    limit_train_batches=500,
    max_epochs=20,
    callbacks=[CustomThroughputMonitor()],
    log_every_n_steps=100,
)

trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)

Error messages and logs

raise ValueError(f"Expected the value to increase, last: {last}, current: {x}")
ValueError: Expected the value to increase, last: 358569066962426.94, current: 358569066962426.94

Environment

No response

More info

Proposed fix:
Replace the sum() with the last elapsed value of _time:

    @rank_zero_only
    def on_validation_end(self, trainer, *_):
        if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
            return

        # _time is cumulative, so its last entry is the total elapsed training time
        training_finished = self._t0s[RunningStage.TRAINING] + self._throughputs[RunningStage.TRAINING]._time[-1]
        time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
        val_time = self._throughputs[RunningStage.VALIDATING]._time[-1]
        self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
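
Because each entry of _time is already the cumulative elapsed time since the corresponding _t0s timestamp, _time[-1] on its own is the total duration of that phase so far; adding it to the phase's start time gives the point at which the phase ended, which is exactly what the offset bookkeeping above needs, so no summation is required.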
