Split dataset by node: index error when sharding iterable dataset #5881

@sanchit-gandhi

Describe the bug

Context: we're splitting an iterable dataset by node and then passing it to a torch DataLoader with multiple workers.

When we iterate over it for 5 steps, we don't get an error.

When we instead iterate over it for 8 steps, we get an IndexError when fetching the data if we have too many workers.
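
Schematically, the failing pattern is the one below. This is a minimal self-contained sketch with a small in-memory dataset standing in for the streaming one (the data and shard counts here are made up for illustration); the full reproduction with the real dataset follows in the next section:

from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

# small in-memory stand-in for the streaming dataset (hypothetical data)
ds = Dataset.from_dict({"text": [f"example {i}" for i in range(32)]})
ds = ds.to_iterable_dataset(num_shards=4)
# split the iterable dataset by node, then hand it to a torch DataLoader
ds = split_dataset_by_node(ds, rank=0, world_size=2)  # this node keeps 2 of the 4 shards
loader = DataLoader(ds, batch_size=4, num_workers=2, drop_last=True)
for batch in loader:
    pass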

Steps to reproduce the bug

Here, we have 2 JAX processes (jax.process_count() = 2) over which we split the dataset. The dataset loading script can be found here: https://huggingface.co/datasets/distil-whisper/librispeech_asr/blob/c6a1e805cbfeed5057400ac5937327d7e30281b8/librispeech_asr.py#L310

Code to reproduce
from datasets import load_dataset
import jax
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from tqdm import tqdm

# load an example dataset (https://huggingface.co/datasets/distil-whisper/librispeech_asr)
dataset = load_dataset("distil-whisper/librispeech_asr", "all", split="train.clean.100", streaming=True)
# just keep the text column -> no need to define a collator
dataset_text = dataset.remove_columns(set(dataset.features.keys()) - {"text"})

# define some constants
batch_size = 256
num_examples = 5  # works for 5 examples, doesn't for 8
num_workers = dataset_text.n_shards  # 14 shards for this dataset (computed before splitting by node)

# try with multiple workers
dataloader = DataLoader(dataset_text, batch_size=batch_size, num_workers=num_workers, drop_last=True)

for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Multiple workers"):
    if i == num_examples:
        break

# try splitting by node (we can't do this with `dataset_text` since `split_dataset_by_node` expects the Audio column for an ASR dataset)
dataset = split_dataset_by_node(dataset, rank=jax.process_index(), world_size=jax.process_count())
# remove the text column again
dataset_text = dataset.remove_columns(set(dataset.features.keys()) - {"text"})
dataloader = DataLoader(dataset_text, batch_size=16, num_workers=num_workers // 2, drop_last=True)

for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Split by node"):
    if i == num_examples:
        break

# too many workers
dataloader = DataLoader(dataset_text, batch_size=256, num_workers=num_workers, drop_last=True)
for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Too many workers"):
    if i == num_examples:
        break
With 5 examples:
Multiple workers: 100%|███████████████████████████████| 5/5 [00:16<00:00,  3.33s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Split by node: 100%|███████████████████████████████| 5/5 [00:13<00:00,  2.76s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Too many dataloader workers: 14 (max is dataset.n_shards=7). Stopping 7 dataloader workers.
To parallelize data loading, we give each process some shards (or data sources) to process. Therefore it's unnecessary to have a number of workers greater than dataset.n_shards=7. To enable more parallelism, please split the dataset in more files than 7.
Too many workers: 100%|███████████████████████████████| 5/5 [00:15<00:00,  3.03s/it]
With 8 examples:
Multiple workers: 100%|███████████████████████████████| 8/8 [00:13<00:00,  1.71s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Split by node: 100%|███████████████████████████████| 8/8 [00:11<00:00,  1.38s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Too many dataloader workers: 14 (max is dataset.n_shards=7). Stopping 7 dataloader workers.
To parallelize data loading, we give each process some shards (or data sources) to process. Therefore it's unnecessary to have a number of workers greater than dataset.n_shards=7. To enable more parallelism, please split the dataset in more files than 7.
Too many workers:  88%|███████████████████████████▋   | 7/8 [00:13<00:01,  1.89s/it]
Traceback (most recent call last):
  File "distil-whisper/test_librispeech.py", line 36, in <module>
    for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Too many workers"):
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1325, in _next_data
    return self._process_data(data)
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 7.
Original Traceback (most recent call last):
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 986, in __iter__
    yield from self._iter_pytorch(ex_iterable)
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 920, in _iter_pytorch
    for key, example in ex_iterable.shard_data_sources(worker_info.id, worker_info.num_workers):
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 540, in shard_data_sources
    self.ex_iterable.shard_data_sources(worker_id, num_workers),
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 796, in shard_data_sources
    self.ex_iterable.shard_data_sources(worker_id, num_workers),
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 126, in shard_data_sources
    requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
  File "/home/sanchitgandhi/datasets/src/datasets/utils/sharding.py", line 76, in _merge_gen_kwargs
    for key in gen_kwargs_list[0]
IndexError: list index out of range
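
Reading the traceback, the error is raised in _merge_gen_kwargs when it indexes gen_kwargs_list[0], which suggests a worker ended up with an empty list of shards to merge. Below is a plain-Python illustration of that suspected mechanism; it is not the library code, just a sketch showing how an empty shard assignment trips the same indexing (the merge function and the round-robin assignment are simplified, hypothetical stand-ins):

# Sketch of the suspected failure mode: merging gen_kwargs for a worker whose
# shard assignment is empty indexes gen_kwargs_list[0], matching the last
# frame of the traceback above.
def merge_gen_kwargs(gen_kwargs_list):
    # simplified, hypothetical stand-in for datasets.utils.sharding._merge_gen_kwargs
    return {key: [kwargs[key] for kwargs in gen_kwargs_list] for key in gen_kwargs_list[0]}

shards = [{"files": [f"shard-{i}.tar"]} for i in range(7)]  # 7 shards on this node
num_workers = 14  # more dataloader workers than shards
for worker_id in range(num_workers):
    assigned = shards[worker_id::num_workers]  # workers 7..13 receive no shards
    try:
        merge_gen_kwargs(assigned)
    except IndexError:
        print(f"worker {worker_id}: IndexError (empty shard list)")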

Expected behavior

Should pass for both 5 and 8 examples.
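
In the meantime, a possible workaround (a hedged sketch, not a verified fix) is to size num_workers from the post-split n_shards rather than from the value computed before the split: in the repro above, num_workers was derived from the 14 pre-split shards and reused after split_dataset_by_node had already reduced the per-node shard count to 7, and the "Split by node" run with 7 workers does complete:

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

dataset = load_dataset(
    "distil-whisper/librispeech_asr", "all", split="train.clean.100", streaming=True
)
# rank/world_size hard-coded for illustration; the repro uses
# jax.process_index() / jax.process_count()
dataset = split_dataset_by_node(dataset, rank=0, world_size=2)
# n_shards now reflects the per-node shard count (7 here), so use it to cap the workers
dataloader = DataLoader(
    dataset, batch_size=256, num_workers=dataset.n_shards, drop_last=True
)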

Environment info

  • datasets version: 2.12.1.dev0
  • Platform: Linux-5.13.0-1023-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.14.1
  • PyArrow version: 12.0.0
  • Pandas version: 2.0.1
