Skip to content

Commit 32637fa

Browse files
authored
fix: preprocess yielding whole dataset to each worker (axolotl-ai-cloud#2503) [skip ci]
1 parent f776f88 commit 32637fa

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/axolotl/utils/data/sft.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,16 +332,23 @@ def load_tokenized_prepared_datasets(
332332
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
333333
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
334334
if isinstance(dataset, IterableDataset):
335+
num_workers = cfg.dataset_processes
335336

336-
def gen_from_iter_ds(_ds, _=None):
337-
yield from _ds
337+
def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
338+
"""Generator function to correctly splice the dataset for each worker"""
339+
for i, item in enumerate(_ds):
340+
if i % num_workers[0] == worker_id[0]:
341+
yield item
338342

339343
ds_from_iter = Dataset.from_generator(
340344
functools.partial(gen_from_iter_ds, dataset),
341345
features=dataset.features,
342-
num_proc=cfg.dataset_processes,
346+
num_proc=num_workers,
343347
split=split,
344-
gen_kwargs={"_": list(range(cfg.dataset_processes))},
348+
gen_kwargs={
349+
"worker_id": list(range(num_workers)),
350+
"num_workers": [num_workers] * num_workers,
351+
},
345352
)
346353
ds_from_iter.save_to_disk(str(prepared_ds_path))
347354
else:

0 commit comments

Comments
 (0)