ShuffledDataSourcesArrowExamplesIterable cannot properly resume from checkpoint #7901

@howitry

Describe the bug

After calling shuffle() on an iterable dataset, ShuffledDataSourcesArrowExamplesIterable cannot properly resume from a checkpoint saved with state_dict().

Steps to reproduce the bug

  1. The code to reproduce is as follows:
from datasets import Dataset
ds = Dataset.from_dict({"a": range(12)}).to_iterable_dataset(num_shards=1)
ds = ds.shuffle(seed=42)

for idx, example in enumerate(ds):
    print(example)
    if idx == 2:  # Resuming works only when the checkpoint is taken at idx <= 1.
        state_dict = ds.state_dict()
        print("checkpoint")
        break

print("state_dict: ",state_dict)

ds.load_state_dict(state_dict)
print("restart from checkpoint")
for example in ds:
    print(example)
  2. The output is as follows (no examples are printed after the restart):
{'a': 0}
{'a': 7}
{'a': 6}
checkpoint
state_dict:  {'examples_iterable': {'examples_iterable': {'examples_iterable': {'shard_idx': 1, 'shard_example_idx': 0, 'type': 'ShuffledDataSourcesArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ShuffledDataSourcesArrowExamplesIterable'}, 'batch_idx': 12, 'num_chunks_since_previous_state': 12, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'previous_state': {'examples_iterable': {'shard_idx': 1, 'shard_example_idx': 0, 'type': 'ShuffledDataSourcesArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ShuffledDataSourcesArrowExamplesIterable'}, 'batch_idx': 12, 'num_chunks_since_previous_state': 12, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'batch_idx': 3, 'num_chunks_since_previous_state': 2, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 0}
restart from checkpoint
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.

Expected behavior

Resuming should work from a checkpoint taken at any index, continuing with the examples that were not yet yielded. Currently the checkpoint loads correctly only when it is taken at idx <= 1.
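
To make the expected behavior concrete, here is a minimal sketch of a check that could be run for every checkpoint index. The helper name `resumes_correctly` and the `make_ds` factory are hypothetical and not part of datasets; the sketch only assumes the object is iterable and exposes `state_dict()` and `load_state_dict()`, as in the reproduction above. Unlike the reproduction, it loads the state into a fresh object rather than the same one.

```python
from typing import Any, Callable


def resumes_correctly(make_ds: Callable[[], Any], checkpoint_idx: int) -> bool:
    """Return True if resuming after `checkpoint_idx` yields exactly the remaining examples.

    `make_ds` (hypothetical) must build a fresh, deterministically ordered
    iterable that exposes state_dict() and load_state_dict().
    """
    # Reference order: one uninterrupted pass over a fresh dataset.
    expected = list(make_ds())

    # Interrupted pass: stop right after the example at `checkpoint_idx`.
    ds = make_ds()
    seen = []
    state = None
    for idx, example in enumerate(ds):
        seen.append(example)
        if idx == checkpoint_idx:
            state = ds.state_dict()
            break
    if state is None:
        raise ValueError("checkpoint_idx is past the end of the dataset")

    # Resume in a fresh object, as one would after a restart.
    resumed_ds = make_ds()
    resumed_ds.load_state_dict(state)
    resumed = list(resumed_ds)

    # Correct resume: examples seen before the checkpoint plus examples
    # yielded after resuming equal the uninterrupted pass.
    return seen + resumed == expected
```

With datasets, `make_ds` could be a function returning `Dataset.from_dict({"a": range(12)}).to_iterable_dataset(num_shards=1).shuffle(seed=42)`. The expectation described in this issue is that the check returns True for every `checkpoint_idx` from 0 to 11.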

Environment info

datasets version: 4.4.1
