Skip to content

Commit ade5624

Browse files
authored
[dataset] fix dataset ddp write conflict (#4860)
1 parent 912cc0c commit ade5624

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

swift/llm/dataset/loader.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def _load_dataset_path(
209209
kwargs = {'split': 'train', 'streaming': streaming, 'num_proc': num_proc}
210210
if file_type == 'csv':
211211
kwargs['na_filter'] = False
212-
dataset = hf_load_dataset(file_type, data_files=dataset_path, **kwargs)
212+
with safe_ddp_context(None, True):
213+
dataset = hf_load_dataset(file_type, data_files=dataset_path, **kwargs)
213214
if columns:
214215
dataset = RowPreprocessor.safe_rename_columns(dataset, columns)
215216
dataset = dataset_meta.preprocess_func(
@@ -315,7 +316,8 @@ def _select_subsets(subsets: List[str], dataset_meta: DatasetMeta) -> List[Subse
315316
@staticmethod
316317
def shuffle_dataset(dataset, seed: int, buffer_size: int = 1000):
317318
if isinstance(dataset, HfDataset):
318-
return dataset.shuffle(seed=seed)
319+
with safe_ddp_context(None, True):
320+
return dataset.shuffle(seed=seed)
319321
else:
320322
return dataset.shuffle(seed=seed, buffer_size=buffer_size)
321323

@@ -366,8 +368,9 @@ def post_process(
366368
val_sample = max(int(train_len * split_dataset_ratio), 1)
367369
train_sample = dataset_sample - val_sample
368370
assert train_sample > 0
369-
train_dataset, val_dataset = train_dataset.train_test_split(
370-
test_size=val_sample, shuffle=shuffle, seed=get_seed(random_state)).values()
371+
with safe_ddp_context(None, True):
372+
train_dataset, val_dataset = train_dataset.train_test_split(
373+
test_size=val_sample, shuffle=shuffle, seed=get_seed(random_state)).values()
371374
train_dataset = sample_dataset(train_dataset, train_sample, shuffle, random_state)
372375
return train_dataset, val_dataset
373376

0 commit comments

Comments
 (0)