
Resuming Training with New Dataset Fails #263

@schopra8

Description

🐛 Bug

If you train a model on one dataset for N epochs and then resume from its checkpoint with a new dataset, LitData raises a ValueError because the restored dataset state still refers to the original dataset's input_dir.

To Reproduce

Steps to reproduce the behavior:

  1. Train a model with dataset-1
  2. Stop training after the first checkpoint is saved
  3. Resume training with trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) where the datamodule now points to dataset-2.
  4. Observe the following error:
[rank7]: Original Traceback (most recent call last):
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank7]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank7]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank7]:     self.dataset_iter = iter(dataset)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 155, in __iter__
[rank7]:     self._iterator = _CombinedDatasetIterator(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in __init__
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 203, in <listcomp>
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 219, in __iter__
[rank7]:     self._validate_state_dict()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 447, in _validate_state_dict
[rank7]:     raise ValueError(
[rank7]: ValueError: The provided input_dir URL state doesn't match the current one. Found s3://dataset-2 instead of s3://dataset-1.
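To make the failure mode concrete, here is a hypothetical, simplified model of the kind of check the traceback points at (`_validate_state_dict` in `litdata/streaming/dataset.py`). It is not LitData's actual code: the function name `validate_restored_state` and the `input_dir_url` key are illustrative assumptions. The point is only that the restored dataset state remembers which dataset URL it was captured for, so resuming against a different URL is rejected.

```python
# Hypothetical, simplified model of the check that fails in the traceback.
# NOT litdata's implementation; names and keys are illustrative assumptions.
from typing import Any, Dict


def validate_restored_state(restored: Dict[str, Any], current_input_dir: str) -> None:
    """Raise if a restored dataset state was recorded for a different input_dir."""
    # The restored state remembers which dataset it was captured for.
    saved_url = restored["input_dir_url"]  # hypothetical key name
    if saved_url != current_input_dir:
        raise ValueError(
            "The provided input_dir URL state doesn't match the current one. "
            f"Found {current_input_dir} instead of {saved_url}."
        )


# State captured while training on dataset-1 ...
state_from_checkpoint = {"input_dir_url": "s3://dataset-1", "num_samples_yielded": 1024}

# ... restored into a dataset that now points at dataset-2.
message = ""
try:
    validate_restored_state(state_from_checkpoint, "s3://dataset-2")
except ValueError as err:
    message = str(err)

print(message)
```

Resuming with the same URL passes silently; only a changed input_dir trips the check, which matches the error above.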

Code sample

Expected behavior

Training should resume with the checkpoint's optimizer states, model weights, etc., but iterate over the entirely new dataset.
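One possible workaround, sketched here and untested against this exact setup, is to strip the saved dataloader state from the checkpoint before resuming, so that only model/optimizer/loop state is restored. The key path below (`loops` → `fit_loop` → `state_dict` → `combined_loader`) is an assumption about where Lightning stores dataloader state and should be checked against your Lightning version; `drop_nested_key` is a hypothetical helper.

```python
# Sketch: remove the saved dataloader state from a Lightning checkpoint dict.
# ASSUMPTION: dataloader state lives at
#   checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
# Verify this path for your Lightning version before relying on it.
from typing import Any, Dict, Sequence

DATALOADER_STATE_PATH = ("loops", "fit_loop", "state_dict", "combined_loader")


def drop_nested_key(checkpoint: Dict[str, Any], path: Sequence[str]) -> bool:
    """Delete checkpoint[path[0]]...[path[-1]] in place; return True if removed."""
    node: Any = checkpoint
    for key in path[:-1]:
        if not isinstance(node, dict) or key not in node:
            return False  # path not present; nothing to drop
        node = node[key]
    if isinstance(node, dict) and path[-1] in node:
        del node[path[-1]]
        return True
    return False


# Toy checkpoint standing in for the dict returned by torch.load(ckpt_path).
ckpt = {
    "state_dict": {"layer.weight": [0.1, 0.2]},
    "optimizer_states": [{}],
    "loops": {
        "fit_loop": {
            "state_dict": {"combined_loader": [{"input_dir_url": "s3://dataset-1"}]}
        }
    },
}
removed = drop_nested_key(ckpt, DATALOADER_STATE_PATH)
```

With a real checkpoint, you would load it with `torch.load(ckpt_path, map_location="cpu")`, apply the helper, write it back with `torch.save(...)`, and pass the new file as `ckpt_path` to `trainer.fit`. Epoch and step counters are still restored, so data iteration restarts on dataset-2 while the training progress carries over.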

Environment

  • PyTorch Version (e.g., 1.0): 2.3.1
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.10
  • CUDA/cuDNN version: 12.1
  • GPU models and configuration: 2x8 H100
  • Any other relevant information:

Additional context

Metadata

Labels: bug (Something isn't working), help wanted (Extra attention is needed)
