Skip to content

Commit 523ba81

Browse files
Commit message: "committed again"
1 parent 097c0fe commit 523ba81

File tree

2 files changed: +7 additions, -9 deletions

src/datasets/iterable_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@
 
 logger = get_logger(__name__)
 
-Key = Union[int, str]
+Key = Union[int, str, tuple[int, int]]
 
 
 def identity_func(x):

tests/test_builder.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -163,11 +163,9 @@ def _split_generators(self, dl_manager):
         return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})]
 
     def _generate_examples(self, filepaths):
-        idx = 0
-        for filepath in filepaths:
+        for shard_idx, filepath in enumerate(filepaths):
             for i in range(100):
-                yield idx, {"id": i, "filepath": filepath}
-                idx += 1
+                yield (shard_idx, i), {"id": i, "filepath": filepath}
 
 
 class DummyArrowBasedBuilderWithAmbiguousShards(ArrowBasedBuilder):
@@ -766,16 +764,16 @@ def test_config_names(self):
 
     def test_cache_dir_for_data_dir(self):
         with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
-            builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
-            other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
+            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
+            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, config_name="a", data_dir=data_dir)
             self.assertEqual(builder.cache_dir, other_builder.cache_dir)
-            other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir)
+            other_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir)
             self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
 
     def test_cache_dir_for_configured_builder(self):
         with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
             builder_cls = configure_builder_class(
-                DummyBuilderWithManualDownload,
+                DummyGeneratorBasedBuilder,
                 builder_configs=[BuilderConfig(data_dir=data_dir)],
                 default_config_name=None,
                 dataset_name="dummy",

0 commit comments

Comments (0)