
Commit e8708ed

[Trainer] Fix distributed dataloader (#8932)
* fix ddloader, fix uc unittest
* update dataloader
1 parent 277fdb4 commit e8708ed

File tree

3 files changed: +92 −93 lines changed

paddlenlp/data/dist_dataloader.py
paddlenlp/trainer/trainer.py
tests/trainer/test_unified_checkpoint.py


paddlenlp/data/dist_dataloader.py

Lines changed: 12 additions & 4 deletions
@@ -33,6 +33,11 @@ def __len__(self):
         return 0
 
 
+class IterableDummyDataset(paddle.io.IterableDataset):
+    def __iter__(self):
+        return None
+
+
 class DistDataLoader(paddle.io.DataLoader):
     """
     DistDataLoader is a wrapper of paddle.io.DataLoader.
@@ -56,11 +61,14 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
-        eval=False,
+        **kwargs,
     ):
 
+        eval = kwargs.pop("eval", False)
+        is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+
         if dataset is None:
-            dataset = DummyDataset()
+            dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
             logger.info("rank has no data, use Dummpy dataset")
 
         super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)
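
Note on the constructor change above: `eval` moves from a named parameter into `**kwargs`, and the new `is_iterable_dataset` flag rides in the same way, so existing call sites keep working and only DistDataLoader ever sees these options. A minimal, self-contained sketch of that keyword plumbing (ToyLoader is a toy stand-in, not the real DistDataLoader):

class ToyLoader:
    def __init__(self, dataset=None, **kwargs):
        # Optional knobs are popped with defaults, so callers that never pass them are unaffected.
        self.eval = kwargs.pop("eval", False)
        self.is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
        # A rank that holds no data gets a dummy dataset; iterable ranks need an iterable dummy.
        if dataset is None:
            dataset = "IterableDummyDataset" if self.is_iterable_dataset else "DummyDataset"
        self.dataset = dataset

print(ToyLoader().dataset)                          # DummyDataset
print(ToyLoader(is_iterable_dataset=True).dataset)  # IterableDummyDataset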
@@ -200,7 +208,7 @@ def __next__(self):
         try:
             data = next(self._dataloader_iter)
             data = nested_copy_place(data, place=paddle.framework._current_expected_place())
-        except:
-            pass
+        except Exception as e:
+            logger.debug(e)
         data = self._broadcast_data(data)
         return data

paddlenlp/trainer/trainer.py

Lines changed: 78 additions & 69 deletions
@@ -1404,12 +1404,16 @@ def get_train_dataloader(self):
             raise ValueError("We don't need train_dataset when should_load_dataset is False.")
 
         train_dataset = self.train_dataset
+        if self.args.distributed_dataloader:
+            is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
+        else:
+            is_iterable_dataset = self._is_iterable_dataset(train_dataset)
         if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
-        if self._is_iterable_dataset(train_dataset):
-            if self.args.dataset_world_size > 1:
+        if is_iterable_dataset:  # For iterable dataset
+            if self.args.dataset_world_size > 1 and train_dataset is not None:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
                     batch_size=self.args.per_device_train_batch_size,
@@ -1418,24 +1422,28 @@ def get_train_dataloader(self):
                     process_index=self.args.dataset_rank,
                 )
 
+            if self.args.distributed_dataloader:
+                logger.info("Training using DistDataLoader.")
+                additional_configs = {"is_iterable_dataset": True}
+            else:
+                additional_configs = {}
             return _DataLoader(
                 train_dataset,
                 batch_size=self.args.per_device_train_batch_size,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
+            )
+        else:
+            train_sampler = self._get_train_sampler()
+            if self.args.distributed_dataloader:
+                logger.info("Training using DistDataLoader.")
+            return _DataLoader(
+                train_dataset,
+                batch_sampler=train_sampler,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
             )
-
-        train_sampler = self._get_train_sampler()
-
-        if self.args.distributed_dataloader:
-            logger.info("Training using DistDataLoader.")
-
-        return _DataLoader(
-            train_dataset,
-            batch_sampler=train_sampler,
-            collate_fn=self.data_collator,
-            num_workers=self.args.dataloader_num_workers,
-        )
 
     def _get_eval_sampler(self, eval_dataset: Dataset):
         if eval_dataset is None or not has_length(eval_dataset):
@@ -1482,54 +1490,48 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
             raise ValueError("We don't need eval_dataset when should_load_dataset is False.")
 
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
+        if self.args.distributed_dataloader:
+            is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset)
+        else:
+            is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
         if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
-        if self._is_iterable_dataset(eval_dataset):
-            if self.args.dataset_world_size > 1:
+        if is_iterable_dataset:
+            if self.args.dataset_world_size > 1 and eval_dataset is not None:
                 eval_dataset = IterableDatasetShard(
                     eval_dataset,
                     batch_size=self.args.per_device_eval_batch_size,
                     drop_last=self.args.dataloader_drop_last,
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
-
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                    eval=True,
-                )
+                logger.info("Eval using DistDataLoader.")
+                additional_configs = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                )
-
-        eval_sampler = self._get_eval_sampler(eval_dataset)
-
-        if self.args.distributed_dataloader:
-            logger.info("Eval using DistDataLoader.")
-
-            return DistDataLoader(
+                additional_configs = {}
+            return _DataLoader(
                 eval_dataset,
-                batch_sampler=eval_sampler,
+                batch_size=self.args.per_device_eval_batch_size,
                 collate_fn=self.data_collator,
-                num_workers=self.args.dataloader_num_workers,
-                eval=True,
+                num_workers=0,
+                **additional_configs,
             )
         else:
-            return DataLoader(
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            if self.args.distributed_dataloader:
+                logger.info("Eval using DistDataLoader.")
+                additional_configs = {"eval": True}
+            else:
+                additional_configs = {}
+            return _DataLoader(
                 eval_dataset,
                 batch_sampler=eval_sampler,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
             )
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -1548,11 +1550,16 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         if not self.args.should_load_dataset and test_dataset is not None:
             raise ValueError("We don't need test_dataset when should_load_dataset is False.")
 
+        if self.args.distributed_dataloader:
+            is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset)
+        else:
+            is_iterable_dataset = self._is_iterable_dataset(test_dataset)
         if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
-        if self._is_iterable_dataset(test_dataset):
-            if self.args.dataset_world_size > 1:
+        if is_iterable_dataset:
+            if self.args.dataset_world_size > 1 and test_dataset is not None:
                 test_dataset = IterableDatasetShard(
                     test_dataset,
                     batch_size=self.args.per_device_eval_batch_size,
@@ -1562,40 +1569,31 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                 )
 
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator, # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                    eval=True,
-                )
+                logger.info("Test using DistDataLoader.")
+                additional_config = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator, # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                )
-
-        test_sampler = self._get_eval_sampler(test_dataset)
-
-        if self.args.distributed_dataloader:
-            logger.info("Test using DistDataLoader.")
-
-            # We use the same batch_size as for eval.
-            return DistDataLoader(
+                additional_config = {}
+            return _DataLoader(
                 test_dataset,
-                batch_sampler=test_sampler,
+                batch_size=self.args.per_device_eval_batch_size * self.world_size,
                 collate_fn=self.data_collator,
-                drop_last=self.args.dataloader_drop_last,
-                eval=True,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_config,
             )
         else:
-            return DataLoader(
+            test_sampler = self._get_eval_sampler(test_dataset)
+            if self.args.distributed_dataloader:
+                logger.info("Test using DistDataLoader.")
+                additional_config = {"eval": True}
+            else:
+                additional_config = {}
+            # We use the same batch_size as for eval.
+            return _DataLoader(
                 test_dataset,
                 batch_sampler=test_sampler,
                 collate_fn=self.data_collator,
                 drop_last=self.args.dataloader_drop_last,
+                **additional_config,
            )
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
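
The train, eval, and test builders above now share one shape: pick DistDataLoader only when `args.distributed_dataloader` is set, and forward the DistDataLoader-only options (`eval`, `is_iterable_dataset`) through a dict splat so a plain `paddle.io.DataLoader` never receives keyword arguments it does not understand. A condensed sketch of that pattern (the `build_loader` helper is hypothetical, not part of the commit):

from paddle.io import DataLoader

from paddlenlp.data.dist_dataloader import DistDataLoader


def build_loader(dataset, args, collate_fn, is_iterable_dataset, for_eval=False):
    loader_cls = DistDataLoader if args.distributed_dataloader else DataLoader
    extra = {}
    if args.distributed_dataloader:
        # Only DistDataLoader understands these; never pass them to paddle.io.DataLoader.
        extra["eval"] = for_eval
        if is_iterable_dataset:
            extra["is_iterable_dataset"] = True
    return loader_cls(dataset, collate_fn=collate_fn, num_workers=0, **extra)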
@@ -1700,6 +1698,8 @@ def _load_rng_state(self, checkpoint):
 
         if self.args.use_hybrid_parallel:
             if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
+                if self.args.tensor_parallel_degree <= 1:
+                    checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
                 fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(
                     checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
                 )
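
The `_load_rng_state` tweak above drops the checkpoint's `model_parallel_rng` entry when `tensor_parallel_degree <= 1`, presumably because a tensor-parallel RNG state saved by a TP>1 run has no matching tracker state to restore into on a TP=1 run. A toy illustration of the effect (a plain dict, not the real tracker state):

saved_tracker_state = {"global_seed": 123, "local_seed": 456, "model_parallel_rng": 789}
tensor_parallel_degree = 1

if tensor_parallel_degree <= 1:
    # Drop the stale tensor-parallel entry before handing the states to the tracker.
    saved_tracker_state.pop("model_parallel_rng", None)

print(saved_tracker_state)  # {'global_seed': 123, 'local_seed': 456}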
@@ -3210,6 +3210,15 @@ def _get_collator_with_removed_columns(
     def _is_iterable_dataset(self, dataset):
         return isinstance(dataset, paddle.io.IterableDataset)
 
+    def _is_iterable_dataset_distributed(self, dataset):
+        # For distributed dataloaer.
+        is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
+        if dist.get_world_size() > 1:
+            dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)
+        if is_iterable_dataset_tensor.item() == 1:
+            return True
+        return False
+
     def print_config(self, args=None, key=""):
         """
         print config values
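
The new `_is_iterable_dataset_distributed` helper is the core of the fix: with the distributed dataloader, some ranks may hold no dataset at all (and later fall back to a dummy dataset), so each rank's local `isinstance` check can disagree. A MAX all-reduce makes every rank treat the dataset as iterable if any rank does. A standalone sketch of that consensus check (the function name is illustrative; in a single-process run the all-reduce is simply skipped):

import paddle
import paddle.distributed as dist


def dataset_is_iterable_on_any_rank(dataset):
    # 1 if this rank's dataset is iterable, 0 otherwise (also 0 when dataset is None).
    flag = paddle.to_tensor(isinstance(dataset, paddle.io.IterableDataset)).reshape([1])
    if dist.get_world_size() > 1:
        # MAX across ranks: if any rank sees an iterable dataset, all ranks agree it is iterable.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())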

tests/trainer/test_unified_checkpoint.py

Lines changed: 2 additions & 20 deletions
@@ -659,7 +659,7 @@ def setUp(self):
         self.need_allclose = True
         self.rtol = 1e-7
 
-        self.run_pretrain_file = "llm/llama/run_pretrain.py"
+        self.run_pretrain_file = "llm/run_pretrain.py"
 
     def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
@@ -701,7 +701,7 @@ def setUp(self):
         self.need_allclose = True
         self.rtol = 1e-7
 
-        self.run_pretrain_file = "llm/llama/run_pretrain.py"
+        self.run_pretrain_file = "llm/run_pretrain.py"
         self.filelists = [
             "config.json",
             "master_weights-00001-of-00002.safetensors",
@@ -1132,24 +1132,6 @@ def rerun(self, train_args):
         np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)
 
 
-@pytest.mark.skipif(True, reason="Skip for None CE")
-class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase):
-    def setUp(self):
-        super().setUp()
-        for config_key in self.configs:
-            self.configs[config_key]["unified_checkpoint"] = 1
-            self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options"
-
-        self.need_allclose = True
-        self.rtol = 1e-7
-
-    def runfirst(self, train_args):
-        self.run_n1c8(self.run_pretrain_file, **train_args)
-
-    def rerun(self, train_args):
-        self.run_n1c8(self.run_pretrain_file, **train_args)
-
-
 @pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN1C8SaveLoadSpeed(TestUnifiedCheckpointFull):
     def setUp(self):
