Describe the bug
I ran into this issue when processing a 900 MB file containing just one example; the reproducer below is simplified so it runs quickly. The problem is that, if `num_shards` is not explicitly set, it is derived from the dataset size using https://github.com/huggingface/datasets/blob/main/src/datasets/utils/py_utils.py#L96 with the default `config.MAX_SHARD_SIZE` of 500MB. If a single example is larger than this limit, the computed number of shards exceeds the number of examples, and saving fails with an `IndexError` when the shards are processed individually.
An easy workaround is `dataset.save_to_disk(output_path, max_shard_size="1GB")` or `dataset.save_to_disk(output_path, num_shards=1)`.
I believe this should be fixed, since it can come up in edge cases with image data, especially when just testing single partitions. The fix would be rather easy: use `num_shards = min(num_examples, <previously_calculated_num_shards>)`.
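For illustration, a minimal sketch of that cap (`resolve_num_shards` is a made-up helper name and the size math only approximates what `save_to_disk` does; the point is the `min(...)` at the end):

```python
from math import ceil

def resolve_num_shards(dataset_nbytes: int, num_examples: int, max_shard_size: int) -> int:
    """Hypothetical helper illustrating the proposed cap; the real logic lives in
    Dataset.save_to_disk, which derives num_shards from the estimated dataset size."""
    num_shards = max(ceil(dataset_nbytes / max_shard_size), 1)  # roughly the current behaviour
    # Proposed fix: never request more shards than there are examples, so that
    # dataset.shard(num_shards=..., index=..., contiguous=True) always gets a valid range.
    return min(num_shards, num_examples) if num_examples > 0 else 1

# One 2 MB example with a 1 MB shard limit:
print(resolve_num_shards(2 * 1024 * 1024, num_examples=1, max_shard_size=1024 * 1024))  # 1 instead of 2
```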
Steps to reproduce the bug
```python
from datasets import Dataset

target_size = 2 * 1024 * 1024  # 2 MB in bytes
base_text = (
    "This is a sample sentence that will be repeated many times to create a large dataset. "
    * 100
)
large_text = ""
while len(large_text.encode("utf-8")) < target_size:
    large_text += base_text

actual_size = len(large_text.encode("utf-8"))
size_mb = actual_size / (1024 * 1024)

data = {"text": [large_text], "label": [0], "id": [1]}
dataset = Dataset.from_dict(data)

output_path = "./sample_dataset"
# make sure this is split into 2 shards
dataset.save_to_disk(output_path, max_shard_size="1MB")
```

This results in:
```bash
Saving the dataset (1/3 shards): 100%|████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 162.96 examples/s]
Traceback (most recent call last):
File "/home/tpitters/programming/toy-mmu/create_dataset.py", line 27, in <module>
dataset.save_to_disk(output_path, max_shard_size="1MB")
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 1640, in save_to_disk
for kwargs in kwargs_per_job:
^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 1617, in <genexpr>
"shard": self.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 4987, in shard
return self.select(
~~~~~~~~~~~^
indices=indices,
^^^^^^^^^^^^^^^^
...<2 lines>...
writer_batch_size=writer_batch_size,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 562, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
~~~~^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/fingerprint.py", line 442, in wrapper
out = func(dataset, *args, **kwargs)
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 4104, in select
return self._select_contiguous(start, length, new_fingerprint=new_fingerprint)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 562, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
~~~~^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/fingerprint.py", line 442, in wrapper
out = func(dataset, *args, **kwargs)
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 4164, in _select_contiguous
_check_valid_indices_value(start, len(self))
~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/tpitters/programming/toy-mmu/.venv/lib/python3.13/site-packages/datasets/arrow_dataset.py", line 624, in _check_valid_indices_value
raise IndexError(f"Index {index} out of range for dataset of size {size}.")
IndexError: Index 1 out of range for dataset of size 1.
```
Expected behavior
Saving should succeed instead of raising an `IndexError`.
Environment info
datasets==4.4.2