@@ -209,7 +209,8 @@ def _load_dataset_path(
209209 kwargs = {'split' : 'train' , 'streaming' : streaming , 'num_proc' : num_proc }
210210 if file_type == 'csv' :
211211 kwargs ['na_filter' ] = False
212- dataset = hf_load_dataset (file_type , data_files = dataset_path , ** kwargs )
212+ with safe_ddp_context (None , True ):
213+ dataset = hf_load_dataset (file_type , data_files = dataset_path , ** kwargs )
213214 if columns :
214215 dataset = RowPreprocessor .safe_rename_columns (dataset , columns )
215216 dataset = dataset_meta .preprocess_func (
@@ -315,7 +316,8 @@ def _select_subsets(subsets: List[str], dataset_meta: DatasetMeta) -> List[Subse
@staticmethod
def shuffle_dataset(dataset, seed: int, buffer_size: int = 1000):
    """Shuffle *dataset* deterministically with *seed*.

    Iterable (streaming) datasets are shuffled with a bounded reservoir of
    ``buffer_size`` examples; map-style ``HfDataset`` objects are shuffled
    under the DDP guard so only one rank materializes the shuffled indices.
    """
    if not isinstance(dataset, HfDataset):
        # Streaming dataset: buffered shuffle, no on-disk cache involved.
        return dataset.shuffle(seed=seed, buffer_size=buffer_size)
    # Map-style dataset: serialize the cache-writing shuffle across ranks.
    with safe_ddp_context(None, True):
        return dataset.shuffle(seed=seed)
321323
@@ -366,8 +368,9 @@ def post_process(
366368 val_sample = max (int (train_len * split_dataset_ratio ), 1 )
367369 train_sample = dataset_sample - val_sample
368370 assert train_sample > 0
369- train_dataset , val_dataset = train_dataset .train_test_split (
370- test_size = val_sample , shuffle = shuffle , seed = get_seed (random_state )).values ()
371+ with safe_ddp_context (None , True ):
372+ train_dataset , val_dataset = train_dataset .train_test_split (
373+ test_size = val_sample , shuffle = shuffle , seed = get_seed (random_state )).values ()
371374 train_dataset = sample_dataset (train_dataset , train_sample , shuffle , random_state )
372375 return train_dataset , val_dataset
373376
0 commit comments