
StreamingDataLoader state is not loaded from checkpoint when resuming training #775

@philgzl

Description

πŸ› Bug

The answer given in #249 suggests that Trainer integrates seamlessly with StreamingDataLoader: the StreamingDataLoader state is included when saving checkpoints and restored when resuming training. However, this does not seem to be the case; see the example code below.

If this is expected, i.e. if users are required to implement extra logic to restore the dataloader state, then an example of how to do so should be added to the README, as suggested in the original issue.
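
In the meantime, here is a rough, untested sketch of what such extra logic might look like: persisting the loader state through the datamodule's own state_dict / load_state_dict checkpoint hooks (these hooks exist on LightningDataModule in recent releases; WorkaroundDataModule is a hypothetical name, and the exact order in which the hooks fire when resuming may be version-dependent).

from lightning import LightningDataModule
from litdata import StreamingDataLoader, StreamingDataset


class WorkaroundDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self._loader = None
        self._pending_state = None

    def setup(self, stage):
        self.dset = StreamingDataset("temp/")

    def train_dataloader(self):
        self._loader = StreamingDataLoader(self.dset)
        if self._pending_state is not None:
            # apply the loader state that was restored from the checkpoint
            self._loader.load_state_dict(self._pending_state)
            self._pending_state = None
        return self._loader

    def state_dict(self):
        # saved into the Trainer checkpoint alongside the model weights
        if self._loader is None:
            return {}
        return {"loader": self._loader.state_dict()}

    def load_state_dict(self, state_dict):
        # called by the Trainer when resuming; the dataloader may not exist
        # yet, so stash the state and apply it in train_dataloader()
        self._pending_state = state_dict.get("loader")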

To Reproduce

import torch
import torch.nn as nn
from lightning import LightningDataModule, LightningModule, Trainer
from litdata import StreamingDataLoader, StreamingDataset
from litdata.streaming import Cache


class MyTrainingException(Exception):
    pass


class MyStreamingDataLoader(StreamingDataLoader):
    def load_state_dict(self, obj):
        # raise so that a successful state restore is immediately visible
        raise ValueError("StreamingDataLoader.load_state_dict called! Hooray!")


class MyLightningDataModule(LightningDataModule):
    def prepare_data(self):
        # write a tiny optimized dataset (one item per chunk) to temp/
        cache = Cache("temp/", chunk_size=1)
        dset_len = 10
        for i in range(dset_len):
            cache[i] = i
        cache.done()
        cache.merge()

    def setup(self, stage):
        self.dset = StreamingDataset("temp/")

    def train_dataloader(self):
        return MyStreamingDataLoader(self.dset)


class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(1))

    def training_step(self, batch):
        # crash deterministically on epoch 2 so training can be resumed
        if self.current_epoch == 2:
            raise MyTrainingException
        return batch * self.param

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


dm = MyLightningDataModule()
model = MyLightningModule()
trainer = Trainer(max_epochs=10, default_root_dir="temp/")
try:
    trainer.fit(model, dm)
except MyTrainingException:
    print("Training crashed as expected on epoch 2.")
# resume training
dm = MyLightningDataModule()
model = MyLightningModule()
trainer = Trainer(max_epochs=10, default_root_dir="temp/")
# the call below should raise ValueError("StreamingDataLoader.load_state_dict called! Hooray!")
# but it raises MyTrainingException again, which means the dataloader state is not loaded from the checkpoint
trainer.fit(model, dm, ckpt_path="temp/lightning_logs/version_0/checkpoints/epoch=1-step=20.ckpt")
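
As a quick sanity check (not part of the repro above), the checkpoint file can also be inspected directly to see whether any dataloader state was written at all. The key layout is Lightning-internal and may differ across versions:

import torch

ckpt = torch.load(
    "temp/lightning_logs/version_0/checkpoints/epoch=1-step=20.ckpt",
    map_location="cpu",
    # on torch >= 2.6 the default weights_only=True cannot unpickle a
    # full Lightning checkpoint; pass weights_only=False there
)
# look for a loader/loop-related entry among the top-level keys
print(list(ckpt.keys()))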
