stateful dataloaders do not load their state_dict if self.trainer.estimated_stepping_batches called beforehand #20550

@ItamarKanter

Bug description

Stateful dataloaders do not load their state_dict, and therefore do not restore their state, if trainer.estimated_stepping_batches is called beforehand.
This comes up when using torch.optim.lr_scheduler.OneCycleLR, which requires total_steps.
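
For context, the call typically sits in configure_optimizers, roughly as in the sketch below (illustrative only; the SGD optimizer and max_lr value are placeholders, not from this report):

def configure_optimizers(self):  # method on a LightningModule
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.1,
        total_steps=int(self.trainer.estimated_stepping_batches),  # the call in question
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }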

What version are you seeing the problem on?

v2.5

How to reproduce the bug

This code is adapted from the Lightning test test_resume_mid_epoch_warning:

from pathlib import Path
import torch
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities import CombinedLoader
from lightning.pytorch import Trainer


class NotStatefulIterable:
    def __init__(self, start=0):
        self.index = start

    def __iter__(self):
        for i in range(self.index, len(self)):
            self.index = i
            yield self.index

    def __len__(self):
        return 10


class StatefulIterable(NotStatefulIterable):
    # Implements the state_dict/load_state_dict protocol that Lightning
    # checkpoints through CombinedLoader.
    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        # Resume from the batch after the last one yielded.
        self.index = state_dict["index"] + 1


# Single stateful dataloader wrapped in a CombinedLoader
train_dataloader_factory = lambda: CombinedLoader(StatefulIterable())
has_state = True
batches_before = [0, 1]
batches_after = [2, 3]

tmp_path = Path(".")


class DummyModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.seen_data = []

    def configure_optimizers(self):
        # Merely calling estimated_stepping_batches (e.g., to obtain total_steps
        # for torch.optim.lr_scheduler.OneCycleLR) triggers the bug.
        _ = int(self.trainer.estimated_stepping_batches)
        return super().configure_optimizers()

    def training_step(self, batch, batch_idx):
        # Record which batches were seen; returning None skips optimization.
        self.seen_data.append(batch)
        print(batch)

    def train_dataloader(self):
        return train_dataloader_factory()


trainer_kwargs = {
    "default_root_dir": tmp_path,
    "accelerator": "cpu",
    "enable_checkpointing": False,
    "enable_model_summary": False,
    "enable_progress_bar": False,
    "logger": False,
    "num_sanity_val_steps": 0,
}

# Train for 2 steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=2, max_epochs=10)
trainer.fit(model)
assert model.seen_data == batches_before

# Save a checkpoint
trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
if has_state:
    assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
else:
    assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]

# Restore training from step 2 and continue 2 more steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=4, max_epochs=10)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
assert model.seen_data == batches_after

Error messages and logs

assert model.seen_data == batches_after
AssertionError

Environment

PyTorch Lightning v2.5 (see above).

More info

The root cause is that trainer.estimated_stepping_batches invokes self.fit_loop.setup_data() during strategy.setup; then, when self.fit_loop.setup_data() is invoked again in self._run_stage(), the state_dict loading is skipped.
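
A possible workaround, sketched below, is to compute total_steps by hand so that trainer.estimated_stepping_batches (and hence the premature fit_loop.setup_data() call) is never reached. This is an untested sketch, not from the report, and it assumes the training dataloader has a usable __len__:

def configure_optimizers(self):  # hypothetical workaround
    # Derive the step budget without touching estimated_stepping_batches.
    steps_per_epoch = len(self.train_dataloader()) // max(1, self.trainer.accumulate_grad_batches)
    max_epochs = self.trainer.max_epochs or 1  # sketch: assume max_epochs was set
    total_steps = steps_per_epoch * max_epochs
    if self.trainer.max_steps != -1:  # max_steps defaults to -1 (unlimited)
        total_steps = min(total_steps, self.trainer.max_steps)
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.1, total_steps=total_steps
    )
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}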
