@@ -116,7 +116,7 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
         raise ValueError("Unified checkpoint only supports PretrainedModel")
 
     if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
-        if is_need_master_weight(args, optimizer):
+        if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
             logger.info(
                 f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
                 "The master weight will be loaded as model weights for next resumption."
@@ -237,9 +237,6 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )
 
-        # confirm parameter cast is executed on the same device as model
-        # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
-        state_dict = nested_copy_place(state_dict, place=paddle.framework._current_expected_place())
         error_msgs += _load_state_dict_into_model(model, state_dict, "")
 
         # force memory release
@@ -1388,7 +1385,6 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             tensor = state_dict[key]
             if key in tp_actions:
                 ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
-                dist.barrier(group=tp_group)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1429,7 +1425,6 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                     )  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
-                    dist.barrier(group=tp_group)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
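In both hunks above, the explicit dist.barrier(group=tp_group) that followed distributed_gather is dropped; the gather-then-merge-on-destination pattern itself is unchanged. The following single-process mock is an illustration of that pattern only, not PaddleNLP code: fake_gather, the "linear.weight" key, and the concatenation action are invented stand-ins for distributed_gather and the real tp_actions.

def fake_gather(local_shards, dst, rank):
    # Mimics the gather contract: the destination rank receives every shard,
    # the other ranks receive None.
    return list(local_shards) if rank == dst else None

tp_degree = 2
dst = 0
# One "weight shard" per tensor-parallel rank (plain lists stand in for tensors).
shards = [[float(r)] * 4 for r in range(tp_degree)]
tp_actions = {"linear.weight": lambda parts: [x for shard in parts for x in shard]}

for rank in range(tp_degree):
    is_dst = rank == dst
    ret = fake_gather(shards, dst=dst, rank=rank)
    # As in merge_tensor_parallel_with_shard / merge_tensor_parallel_for_optimizer,
    # the merge action runs only on the destination rank; the diff above removes
    # the dist.barrier() that used to sit between the gather and this step.
    action = tp_actions["linear.weight"]
    merged = action(ret) if is_dst else None
    print(rank, merged)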
@@ -1631,13 +1626,17 @@ def select_model_weight_index(args, model, resume_from_checkpoint, safe_serializ
 
 
 def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization):
-    if is_need_master_weight(args, optimizer):
+    if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
         if not has_master_weight:
             if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config:
                 index_filename_master_weights = (
                     PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
                 )
                 has_master_weight = True
+                logger.warning(
+                    "The unified checkpoint does not contain master weight, "
+                    "the model weight will be loaded as master weight."
+                )
             else:
                 raise ValueError(
                     "Can't find a valid unified master weight checkpoint,"
@@ -1656,28 +1655,19 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
     return has_master_weight, index_filename_master_weights
 
 
-def is_need_master_weight(args, optimizer):
-    """
-    https://github.com/PaddlePaddle/Paddle/blob/4a9991fb6744443333638b65fb7e225fb2b00a13/python/paddle/amp/auto_cast.py#L485
-    """
+def unwrap_optimizer(optimizer):
+    while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"):
+        if hasattr(optimizer, "_inner_opt"):
+            optimizer = optimizer._inner_opt
+        if hasattr(optimizer, "_optim"):
+            optimizer = optimizer._optim
 
-    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
-        DygraphShardingOptimizer,
-        DygraphShardingOptimizerV2,
-    )
-    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
-        HybridParallelOptimizer,
-    )
-    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
-        GroupShardedOptimizerStage2,
-    )
+    return optimizer
 
-    if isinstance(optimizer, (DygraphShardingOptimizer, DygraphShardingOptimizerV2, HybridParallelOptimizer)):
-        optimizer = optimizer._inner_opt
-    elif isinstance(optimizer, GroupShardedOptimizerStage2):
-        optimizer = optimizer._optim
 
+def is_need_master_weight(optimizer, is_fp16_or_bp16):
+    optimizer = unwrap_optimizer(optimizer)
     if hasattr(optimizer, "_multi_precision"):
-        return optimizer._multi_precision and (args.bf16 or args.fp16)
+        return optimizer._multi_precision and is_fp16_or_bp16
     else:
         return False
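As a concrete illustration of the refactor, the sketch below exercises the new helpers with invented stand-in wrapper classes (any object exposing _inner_opt or _optim); only unwrap_optimizer and is_need_master_weight mirror the added code above, everything else is hypothetical.

# Stand-in wrapper classes, used only to exercise the helpers.
class _BaseOptimizer:
    _multi_precision = True  # plays the role of Paddle's multi-precision flag

class _HybridWrapper:
    def __init__(self, inner):
        self._inner_opt = inner  # sharding / hybrid-parallel style wrapper

class _Stage2Wrapper:
    def __init__(self, inner):
        self._optim = inner  # GroupShardedOptimizerStage2 style wrapper

def unwrap_optimizer(optimizer):
    # Peel nested wrappers until the innermost optimizer is reached.
    while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"):
        if hasattr(optimizer, "_inner_opt"):
            optimizer = optimizer._inner_opt
        if hasattr(optimizer, "_optim"):
            optimizer = optimizer._optim
    return optimizer

def is_need_master_weight(optimizer, is_fp16_or_bp16):
    # Master weights are needed only for a multi-precision optimizer in fp16/bf16.
    optimizer = unwrap_optimizer(optimizer)
    if hasattr(optimizer, "_multi_precision"):
        return optimizer._multi_precision and is_fp16_or_bp16
    else:
        return False

wrapped = _Stage2Wrapper(_HybridWrapper(_BaseOptimizer()))
print(is_need_master_weight(wrapped, is_fp16_or_bp16=True))   # True
print(is_need_master_weight(wrapped, is_fp16_or_bp16=False))  # False

The practical effect of the change: unwrapping is attribute-based rather than isinstance-based, so the helper no longer imports the wrapper classes from paddle.distributed.fleet, and the fp16/bf16 condition is passed in explicitly instead of being read from args.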