@@ -1404,12 +1404,16 @@ def get_train_dataloader(self):
             raise ValueError("We don't need train_dataset when should_load_dataset is False.")
 
         train_dataset = self.train_dataset
+        if self.args.distributed_dataloader:
+            is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
+        else:
+            is_iterable_dataset = self._is_iterable_dataset(train_dataset)
         if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
-        if self._is_iterable_dataset(train_dataset):
-            if self.args.dataset_world_size > 1:
+        if is_iterable_dataset:  # For iterable dataset
+            if self.args.dataset_world_size > 1 and train_dataset is not None:
                 train_dataset = IterableDatasetShard(
                     train_dataset,
                     batch_size=self.args.per_device_train_batch_size,
@@ -1418,24 +1422,28 @@ def get_train_dataloader(self):
                     process_index=self.args.dataset_rank,
                 )
 
+            if self.args.distributed_dataloader:
+                logger.info("Training using DistDataLoader.")
+                additional_configs = {"is_iterable_dataset": True}
+            else:
+                additional_configs = {}
             return _DataLoader(
                 train_dataset,
                 batch_size=self.args.per_device_train_batch_size,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
+            )
+        else:
+            train_sampler = self._get_train_sampler()
+            if self.args.distributed_dataloader:
+                logger.info("Training using DistDataLoader.")
+            return _DataLoader(
+                train_dataset,
+                batch_sampler=train_sampler,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
             )
-
-        train_sampler = self._get_train_sampler()
-
-        if self.args.distributed_dataloader:
-            logger.info("Training using DistDataLoader.")
-
-        return _DataLoader(
-            train_dataset,
-            batch_sampler=train_sampler,
-            collate_fn=self.data_collator,
-            num_workers=self.args.dataloader_num_workers,
-        )
 
     def _get_eval_sampler(self, eval_dataset: Dataset):
         if eval_dataset is None or not has_length(eval_dataset):
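Note: the refactor above collapses the two return paths into a single `_DataLoader` alias plus a conditionally built keyword dictionary, so the DistDataLoader-only flags are forwarded with `**additional_configs` only when `distributed_dataloader` is on. Below is a minimal, self-contained sketch of that dispatch pattern; `ToyDistLoader` and `build_loader` are illustrative stand-ins, not the PaddleNLP classes.

import paddle
from paddle.io import DataLoader, TensorDataset


class ToyDistLoader(DataLoader):
    # Stand-in for DistDataLoader: accepts the extra flags, then behaves like DataLoader.
    def __init__(self, dataset, eval=False, is_iterable_dataset=False, **kwargs):
        self.eval = eval
        self.is_iterable_dataset = is_iterable_dataset
        super().__init__(dataset, **kwargs)


def build_loader(dataset, distributed_dataloader, batch_size=4):
    # Choose the class once, add loader-specific kwargs only when they apply,
    # and write the constructor call a single time.
    _DataLoader = ToyDistLoader if distributed_dataloader else DataLoader
    additional_configs = {"is_iterable_dataset": True} if distributed_dataloader else {}
    return _DataLoader(dataset, batch_size=batch_size, **additional_configs)


ds = TensorDataset([paddle.arange(32, dtype="float32").reshape([8, 4])])
print(type(build_loader(ds, False)).__name__)  # DataLoader
print(type(build_loader(ds, True)).__name__)   # ToyDistLoader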
@@ -1482,54 +1490,48 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
             raise ValueError("We don't need eval_dataset when should_load_dataset is False.")
 
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
+        if self.args.distributed_dataloader:
+            is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset)
+        else:
+            is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
         if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
-        if self._is_iterable_dataset(eval_dataset):
-            if self.args.dataset_world_size > 1:
+        if is_iterable_dataset:
+            if self.args.dataset_world_size > 1 and eval_dataset is not None:
                 eval_dataset = IterableDatasetShard(
                     eval_dataset,
                     batch_size=self.args.per_device_eval_batch_size,
                     drop_last=self.args.dataloader_drop_last,
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
-
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                    eval=True,
-                )
+                logger.info("Eval using DistDataLoader.")
+                additional_configs = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                )
-
-        eval_sampler = self._get_eval_sampler(eval_dataset)
-
-        if self.args.distributed_dataloader:
-            logger.info("Eval using DistDataLoader.")
-
-            return DistDataLoader(
+                additional_configs = {}
+            return _DataLoader(
                 eval_dataset,
-                batch_sampler=eval_sampler,
+                batch_size=self.args.per_device_eval_batch_size,
                 collate_fn=self.data_collator,
-                num_workers=self.args.dataloader_num_workers,
-                eval=True,
+                num_workers=0,
+                **additional_configs,
             )
         else:
-            return DataLoader(
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            if self.args.distributed_dataloader:
+                logger.info("Eval using DistDataLoader.")
+                additional_configs = {"eval": True}
+            else:
+                additional_configs = {}
+            return _DataLoader(
                 eval_dataset,
                 batch_sampler=eval_sampler,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
            )
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
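Note: in the iterable-dataset branch there is no batch sampler to shard, so the loader instead wraps the dataset in IterableDatasetShard and hands each of the dataset_world_size ranks its own slice of every global batch. The generator below is a simplified illustration of that idea; `toy_shard` is a hypothetical helper, not the real IterableDatasetShard (which also handles seeding, padding, and the drop_last corner cases).

from itertools import islice


def toy_shard(iterable, batch_size, num_processes, process_index):
    it = iter(iterable)
    while True:
        # Pull one global batch worth of samples covering all ranks.
        chunk = list(islice(it, batch_size * num_processes))
        if len(chunk) < batch_size * num_processes:
            return  # drop the incomplete tail, similar to drop_last=True
        # Each rank keeps a contiguous batch_size slice of the global batch.
        start = process_index * batch_size
        yield from chunk[start:start + batch_size]


samples = range(12)
print(list(toy_shard(samples, batch_size=2, num_processes=2, process_index=0)))  # [0, 1, 4, 5, 8, 9]
print(list(toy_shard(samples, batch_size=2, num_processes=2, process_index=1)))  # [2, 3, 6, 7, 10, 11]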
@@ -1548,11 +1550,16 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         if not self.args.should_load_dataset and test_dataset is not None:
             raise ValueError("We don't need test_dataset when should_load_dataset is False.")
 
+        if self.args.distributed_dataloader:
+            is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset)
+        else:
+            is_iterable_dataset = self._is_iterable_dataset(test_dataset)
         if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
-        if self._is_iterable_dataset(test_dataset):
-            if self.args.dataset_world_size > 1:
+        if is_iterable_dataset:
+            if self.args.dataset_world_size > 1 and test_dataset is not None:
                 test_dataset = IterableDatasetShard(
                     test_dataset,
                     batch_size=self.args.per_device_eval_batch_size,
@@ -1562,40 +1569,31 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                 )
 
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                    eval=True,
-                )
+                logger.info("Test using DistDataLoader.")
+                additional_config = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                )
-
-        test_sampler = self._get_eval_sampler(test_dataset)
-
-        if self.args.distributed_dataloader:
-            logger.info("Test using DistDataLoader.")
-
-            # We use the same batch_size as for eval.
-            return DistDataLoader(
+                additional_config = {}
+            return _DataLoader(
                 test_dataset,
-                batch_sampler=test_sampler,
+                batch_size=self.args.per_device_eval_batch_size * self.world_size,
                 collate_fn=self.data_collator,
-                drop_last=self.args.dataloader_drop_last,
-                eval=True,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_config,
             )
         else:
-            return DataLoader(
+            test_sampler = self._get_eval_sampler(test_dataset)
+            if self.args.distributed_dataloader:
+                logger.info("Test using DistDataLoader.")
+                additional_config = {"eval": True}
+            else:
+                additional_config = {}
+            # We use the same batch_size as for eval.
+            return _DataLoader(
                 test_dataset,
                 batch_sampler=test_sampler,
                 collate_fn=self.data_collator,
                 drop_last=self.args.dataloader_drop_last,
+                **additional_config,
             )
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
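Note: one behavioral detail worth flagging is that, in the iterable-dataset path, get_test_dataloader keeps the pre-existing batch_size = per_device_eval_batch_size * self.world_size, so the test loader yields one global batch per step rather than the per-device batch used by get_eval_dataloader. With hypothetical numbers:

# Hypothetical values, only to show how the test loader's batch size is derived.
per_device_eval_batch_size = 8
world_size = 4
print(per_device_eval_batch_size * world_size)  # 32 samples per test batch across all ranks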
@@ -1700,6 +1698,8 @@ def _load_rng_state(self, checkpoint):
 
         if self.args.use_hybrid_parallel:
             if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
+                if self.args.tensor_parallel_degree <= 1:
+                    checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
                 fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(
                     checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
                 )
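Note: the new guard strips the model-parallel RNG stream out of the checkpointed tracker state when tensor_parallel_degree <= 1, so set_states_tracker is never asked to restore a stream the current run does not track. A minimal sketch of that filtering step, using a plain dict with made-up sibling keys and values in place of the real checkpoint structure:

# Sketch of pruning an RNG-tracker state dict before restoring it.
checkpoint_rng_state = {
    "hybrid_parallel_rng_state_tracker": {
        "global_seed": 1234,            # illustrative entries; real values are RNG states
        "local_seed": 5678,
        "model_parallel_rng": 9999,
    }
}

tensor_parallel_degree = 1  # hypothetical run with tensor parallelism disabled
if tensor_parallel_degree <= 1:
    # pop(..., None) is a no-op when the key is absent, so checkpoints that never
    # stored a model-parallel stream are handled without a KeyError.
    checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)

print(checkpoint_rng_state["hybrid_parallel_rng_state_tracker"])  # model_parallel_rng removed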
@@ -3210,6 +3210,15 @@ def _get_collator_with_removed_columns(
     def _is_iterable_dataset(self, dataset):
         return isinstance(dataset, paddle.io.IterableDataset)
 
+    def _is_iterable_dataset_distributed(self, dataset):
+        # For distributed dataloader.
+        is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
+        if dist.get_world_size() > 1:
+            dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)
+        if is_iterable_dataset_tensor.item() == 1:
+            return True
+        return False
+
     def print_config(self, args=None, key=""):
         """
         print config values
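Note: the helper added above turns the local is-iterable check into a collective decision: every rank contributes a 0/1 flag and an all_reduce with ReduceOp.MAX acts as a logical OR, so all ranks take the same dataloader branch even when some ranks hold no dataset at all (for example when should_load_dataset is False). A standalone sketch of that consensus pattern, assuming the script is launched with paddle.distributed.launch for the multi-process case:

import paddle
import paddle.distributed as dist
from paddle.io import IterableDataset


def dataset_is_iterable_anywhere(dataset):
    local_flag = isinstance(dataset, IterableDataset)
    if dist.get_world_size() <= 1:
        return local_flag  # single process: no collective needed
    # Encode the local answer as 0/1; MAX across ranks behaves as a logical OR.
    flag = paddle.to_tensor([int(local_flag)], dtype="int32")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


if __name__ == "__main__":
    if dist.get_world_size() > 1:
        dist.init_parallel_env()
    # None stands in for a rank that loads no dataset.
    print(dataset_is_iterable_anywhere(None))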