Description
🐛 Bug
The answer given in #249 suggests that Trainer seamlessly integrates with StreamingDataLoader: when a checkpoint is saved, the state of the StreamingDataLoader is included, and when training is resumed, that state is restored. However, this does not appear to be the case. See the example code below.
If this is expected, i.e. if users are required to implement extra logic to achieve this behavior, then an example of how to do so should be added to the README, as suggested in the original issue.
To Reproduce
import torch
import torch.nn as nn
from lightning import LightningDataModule, LightningModule, Trainer
from litdata import StreamingDataLoader, StreamingDataset
from litdata.streaming import Cache


class MyTrainingException(Exception):
    pass


class MyStreamingDataLoader(StreamingDataLoader):
    # If Trainer restored the dataloader state on resume, this override would be hit.
    def load_state_dict(self, obj):
        raise ValueError("StreamingDataLoader.load_state_dict called! Hooray!")


class MyLightningDataModule(LightningDataModule):
    def prepare_data(self):
        # Write a small optimized dataset to stream from.
        cache = Cache("temp/", chunk_size=1)
        dset_len = 10
        for i in range(dset_len):
            cache[i] = i
        cache.done()
        cache.merge()

    def setup(self, stage):
        self.dset = StreamingDataset("temp/")

    def train_dataloader(self):
        return MyStreamingDataLoader(self.dset)


class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(1))

    def training_step(self, batch):
        # Simulate a crash partway through training.
        if self.current_epoch == 2:
            raise MyTrainingException
        return batch * self.param

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


dm = MyLightningDataModule()
model = MyLightningModule()
trainer = Trainer(max_epochs=10, default_root_dir="temp/")
try:
    trainer.fit(model, dm)
except MyTrainingException:
    print("Training crashed as expected on epoch 2.")

# Resume training from the last checkpoint.
dm = MyLightningDataModule()
model = MyLightningModule()
trainer = Trainer(max_epochs=10, default_root_dir="temp/")
# The call below should raise ValueError("StreamingDataLoader.load_state_dict called! Hooray!"),
# but it raises MyTrainingException again, which means the dataloader state
# is not loaded from the checkpoint.
trainer.fit(model, dm, ckpt_path="temp/lightning_logs/version_0/checkpoints/epoch=1-step=20.ckpt")
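
For reference, below is an untested sketch (not from the original issue) of the extra logic users might currently need: routing the StreamingDataLoader state through the LightningDataModule state_dict/load_state_dict hooks, which Trainer does invoke when saving and restoring checkpoints. The cached self.dataloader attribute is an illustrative name, not part of either API.

class WorkaroundDataModule(LightningDataModule):
    def setup(self, stage):
        self.dset = StreamingDataset("temp/")
        self.dataloader = None  # hypothetical cache of the loader instance

    def train_dataloader(self):
        # Reuse one loader instance so any restored state is not discarded
        # when Trainer asks for the dataloader.
        if self.dataloader is None:
            self.dataloader = StreamingDataLoader(self.dset)
        return self.dataloader

    def state_dict(self):
        # Called by Trainer when a checkpoint is written.
        return {"dataloader": self.dataloader.state_dict() if self.dataloader else None}

    def load_state_dict(self, state_dict):
        # Called by Trainer when resuming from a checkpoint; restore the
        # dataloader state before training continues.
        if state_dict.get("dataloader") is not None:
            self.train_dataloader().load_state_dict(state_dict["dataloader"])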