@@ -209,7 +209,8 @@ def _load_dataset_path(
209
209
kwargs = {'split' : 'train' , 'streaming' : streaming , 'num_proc' : num_proc }
210
210
if file_type == 'csv' :
211
211
kwargs ['na_filter' ] = False
212
- dataset = hf_load_dataset (file_type , data_files = dataset_path , ** kwargs )
212
+ with safe_ddp_context (None , True ):
213
+ dataset = hf_load_dataset (file_type , data_files = dataset_path , ** kwargs )
213
214
if columns :
214
215
dataset = RowPreprocessor .safe_rename_columns (dataset , columns )
215
216
dataset = dataset_meta .preprocess_func (
@@ -315,7 +316,8 @@ def _select_subsets(subsets: List[str], dataset_meta: DatasetMeta) -> List[Subse
315
316
@staticmethod
def shuffle_dataset(dataset, seed: int, buffer_size: int = 1000):
    """Return a shuffled version of ``dataset``.

    Args:
        dataset: The dataset to shuffle. ``HfDataset`` instances take the
            first branch; anything else must expose
            ``shuffle(seed=..., buffer_size=...)``.
        seed: Seed forwarded to the underlying ``shuffle`` call.
        buffer_size: Forwarded to ``shuffle`` only when ``dataset`` is not
            an ``HfDataset``.

    Returns:
        Whatever the underlying ``dataset.shuffle(...)`` call returns.
    """
    if isinstance(dataset, HfDataset):
        # The shuffle of an HfDataset runs inside safe_ddp_context.
        # NOTE(review): presumably this coordinates the operation across
        # distributed ranks -- confirm against safe_ddp_context's definition
        # and the meaning of its (None, True) arguments.
        with safe_ddp_context(None, True):
            return dataset.shuffle(seed=seed)
    # NOTE(review): presumably the iterable/streaming dataset case, where
    # shuffling is buffer-based -- confirm against callers.
    return dataset.shuffle(seed=seed, buffer_size=buffer_size)
321
323
@@ -366,8 +368,9 @@ def post_process(
366
368
val_sample = max (int (train_len * split_dataset_ratio ), 1 )
367
369
train_sample = dataset_sample - val_sample
368
370
assert train_sample > 0
369
- train_dataset , val_dataset = train_dataset .train_test_split (
370
- test_size = val_sample , shuffle = shuffle , seed = get_seed (random_state )).values ()
371
+ with safe_ddp_context (None , True ):
372
+ train_dataset , val_dataset = train_dataset .train_test_split (
373
+ test_size = val_sample , shuffle = shuffle , seed = get_seed (random_state )).values ()
371
374
train_dataset = sample_dataset (train_dataset , train_sample , shuffle , random_state )
372
375
return train_dataset , val_dataset
373
376
0 commit comments