Bug description
Stateful dataloaders do not load their state_dict and restore their state when resuming from a checkpoint if trainer.estimated_stepping_batches is called. The situation comes up when using torch.optim.lr_scheduler.OneCycleLR, which requires total_steps.
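For context, this is the kind of configure_optimizers that triggers the call (a minimal sketch; the SGD optimizer, the learning-rate values, and the OneCycleModel name are placeholders, not part of the reproduction below):

import torch
from lightning.pytorch import LightningModule


class OneCycleModel(LightningModule):
    def configure_optimizers(self):
        # OneCycleLR needs the total number of optimizer steps up front,
        # so configure_optimizers asks the trainer for the estimate here.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=0.1,
            total_steps=int(self.trainer.estimated_stepping_batches),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }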
What version are you seeing the problem on?
v2.5
How to reproduce the bug
This code is adapted from the PL test test_resume_mid_epoch_warning:
from pathlib import Path

import torch
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities import CombinedLoader
from lightning.pytorch import Trainer


class NotStatefulIterable:
    def __init__(self, start=0):
        self.index = start

    def __iter__(self):
        for i in range(self.index, len(self)):
            self.index = i
            yield self.index

    def __len__(self):
        return 10


class StatefulIterable(NotStatefulIterable):
    def state_dict(self):
        return {"index": self.index}

    def load_state_dict(self, state_dict):
        self.index = state_dict["index"] + 1


# Single stateful DataLoader
train_dataloader_factory = lambda: CombinedLoader(StatefulIterable())
has_state = True
batches_before = [0, 1]
batches_after = [2, 3]
tmp_path = Path(".")


class DummyModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.seen_data = []

    def configure_optimizers(self):
        total_steps = int(
            self.trainer.estimated_stepping_batches
        )  # for torch.optim.lr_scheduler.OneCycleLR

    def training_step(self, batch, batch_idx):
        self.seen_data.append(batch)
        print(batch)

    def train_dataloader(self):
        return train_dataloader_factory()


trainer_kwargs = {
    "default_root_dir": tmp_path,
    "accelerator": "cpu",
    "enable_checkpointing": False,
    "enable_model_summary": False,
    "enable_progress_bar": False,
    "logger": False,
    "num_sanity_val_steps": 0,
}

# Train for 2 steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=2, max_epochs=10)
trainer.fit(model)
assert model.seen_data == batches_before

# Save a checkpoint
trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
if has_state:
    assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
else:
    assert "combined_loader" not in checkpoint["loops"]["fit_loop"]["state_dict"]

# Restore training from step 2 and continue 2 more steps
model = DummyModel()
trainer = Trainer(**trainer_kwargs, max_steps=4, max_epochs=10)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
assert model.seen_data == batches_after
Error messages and logs
assert model.seen_data == batches_after
AssertionError
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
More info
It has to do with trainer.estimated_stepping_batches, which invokes self.fit_loop.setup_data() during strategy.setup; then, when self.fit_loop.setup_data() is invoked again in self._run_stage(), it skips the state_dict loading.
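A possible workaround until this is fixed (an untested sketch, not an official API): compute total_steps without touching trainer.estimated_stepping_batches, for example from max_steps or the dataloader length, so that fit_loop.setup_data() is not invoked early during strategy setup. The compute_total_steps helper below is hypothetical:

import math


def compute_total_steps(trainer, train_dataloader_len):
    # Hypothetical helper: derive OneCycleLR's total_steps without calling
    # trainer.estimated_stepping_batches, which is what triggers the early
    # fit_loop.setup_data() call described above.
    if trainer.max_steps > 0:
        return trainer.max_steps
    # Fall back to max_epochs * optimizer steps per epoch (assumes max_epochs is set).
    steps_per_epoch = math.ceil(train_dataloader_len / trainer.accumulate_grad_batches)
    return trainer.max_epochs * steps_per_epoch

Whether avoiding estimated_stepping_batches actually lets the stateful dataloader restore its state would still need to be verified against the reproduction above.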