File tree Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -332,16 +332,23 @@ def load_tokenized_prepared_datasets(
332
332
if cfg .local_rank == 0 and not cfg .skip_prepare_dataset :
333
333
LOG .info (f"Saving merged prepared dataset to disk... { prepared_ds_path } " )
334
334
if isinstance (dataset , IterableDataset ):
335
+ num_workers = cfg .dataset_processes
335
336
336
- def gen_from_iter_ds (_ds , _ = None ):
337
- yield from _ds
337
+ def gen_from_iter_ds (_ds , worker_id : List [int ], num_workers : List [int ]):
338
+ """Generator function that yields only this worker's share of the dataset (items whose index modulo num_workers equals worker_id)"""
339
+ for i , item in enumerate (_ds ):
340
+ if i % num_workers [0 ] == worker_id [0 ]:
341
+ yield item
338
342
339
343
ds_from_iter = Dataset .from_generator (
340
344
functools .partial (gen_from_iter_ds , dataset ),
341
345
features = dataset .features ,
342
- num_proc = cfg . dataset_processes ,
346
+ num_proc = num_workers ,
343
347
split = split ,
344
- gen_kwargs = {"_" : list (range (cfg .dataset_processes ))},
348
+ gen_kwargs = {
349
+ "worker_id" : list (range (num_workers )),
350
+ "num_workers" : [num_workers ] * num_workers ,
351
+ },
345
352
)
346
353
ds_from_iter .save_to_disk (str (prepared_ds_path ))
347
354
else :
You can’t perform that action at this time.
0 commit comments