Commit 1e9b5a8

[unified checkpoint] Update unified checkpoint (#7730)
* replace with unwrap_optimizer
* remove dist.barrier(group=tp_group) for distributed_gather
* add warning when loading model weights as master weights
* update paddle.cast
* Revert "replace with unwrap_optimizer" (this reverts commit 3fd6a58)
* add unwrap_optimizer
* convert check_origin_checkpoint to is_unified_checkpoint
* update
1 parent d1e51e2 commit 1e9b5a8

3 files changed: +50 −47 lines changed
paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 16 additions & 26 deletions
@@ -116,7 +116,7 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
         raise ValueError("Unified checkpoint only supports PretrainedModel")
 
     if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
-        if is_need_master_weight(args, optimizer):
+        if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
             logger.info(
                 f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
                 "The master weight will be loaded as model weights for next resumption."
@@ -237,9 +237,6 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )
 
-        # confirm parameter cast is executed on the same device as model
-        # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
-        state_dict = nested_copy_place(state_dict, place=paddle.framework._current_expected_place())
         error_msgs += _load_state_dict_into_model(model, state_dict, "")
 
         # force memory release
@@ -1388,7 +1385,6 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             tensor = state_dict[key]
             if key in tp_actions:
                 ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
-                dist.barrier(group=tp_group)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1429,7 +1425,6 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                     )  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
-                    dist.barrier(group=tp_group)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
@@ -1631,13 +1626,17 @@ def select_model_weight_index(args, model, resume_from_checkpoint, safe_serializ
 
 
 def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization):
-    if is_need_master_weight(args, optimizer):
+    if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
         if not has_master_weight:
             if UnifiedCheckpointOption.MASTER_WEIGHT_COMPATIBLE.value in args.unified_checkpoint_config:
                 index_filename_master_weights = (
                     PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
                 )
                 has_master_weight = True
+                logger.warning(
+                    "The unified checkpoint does not contain master weight, "
+                    "the model weight will be loaded as master weight."
+                )
             else:
                 raise ValueError(
                     "Can't find a valid unified master weight checkpoint,"
@@ -1656,28 +1655,19 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
     return has_master_weight, index_filename_master_weights
 
 
-def is_need_master_weight(args, optimizer):
-    """
-    https://github.com/PaddlePaddle/Paddle/blob/4a9991fb6744443333638b65fb7e225fb2b00a13/python/paddle/amp/auto_cast.py#L485
-    """
+def unwrap_optimizer(optimizer):
+    while hasattr(optimizer, "_inner_opt") or hasattr(optimizer, "_optim"):
+        if hasattr(optimizer, "_inner_opt"):
+            optimizer = optimizer._inner_opt
+        if hasattr(optimizer, "_optim"):
+            optimizer = optimizer._optim
 
-    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
-        DygraphShardingOptimizer,
-        DygraphShardingOptimizerV2,
-    )
-    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
-        HybridParallelOptimizer,
-    )
-    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
-        GroupShardedOptimizerStage2,
-    )
+    return optimizer
 
-    if isinstance(optimizer, (DygraphShardingOptimizer, DygraphShardingOptimizerV2, HybridParallelOptimizer)):
-        optimizer = optimizer._inner_opt
-    elif isinstance(optimizer, GroupShardedOptimizerStage2):
-        optimizer = optimizer._optim
 
+def is_need_master_weight(optimizer, is_fp16_or_bp16):
+    optimizer = unwrap_optimizer(optimizer)
     if hasattr(optimizer, "_multi_precision"):
-        return optimizer._multi_precision and (args.bf16 or args.fp16)
+        return optimizer._multi_precision and is_fp16_or_bp16
     else:
         return False
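The rewritten helpers drop the fleet-specific imports and isinstance checks in favor of a generic unwrap loop over the private _inner_opt/_optim attributes. Below is a minimal usage sketch for a plain, unwrapped optimizer; it assumes the helpers are importable from paddlenlp.trainer.plugins.unified_checkpoint and that paddle.optimizer.AdamW stores its multi_precision flag on the private _multi_precision attribute (both assumptions, not guaranteed by this diff).

import paddle

from paddlenlp.trainer.plugins.unified_checkpoint import (
    is_need_master_weight,
    unwrap_optimizer,
)

linear = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=linear.parameters(), multi_precision=True
)

# No _inner_opt/_optim wrapper attributes, so unwrap_optimizer returns the
# optimizer unchanged.
assert unwrap_optimizer(opt) is opt

# The check reduces to "_multi_precision and (fp16 or bf16)".
print(is_need_master_weight(opt, is_fp16_or_bp16=True))   # True
print(is_need_master_weight(opt, is_fp16_or_bp16=False))  # False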

paddlenlp/trainer/trainer.py

Lines changed: 25 additions & 19 deletions
@@ -88,9 +88,12 @@
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
     LORA_WEIGHTS_NAME,
+    PADDLE_MASTER_WEIGHTS_INDEX_NAME,
     PADDLE_WEIGHTS_INDEX_NAME,
     PADDLE_WEIGHTS_NAME,
     PREFIX_WEIGHTS_NAME,
+    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
 )
 from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available
 from ..utils.log import logger
@@ -507,9 +510,10 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
 
         if self.args.unified_checkpoint:
             if resume_from_checkpoint is not None:
-                use_unified_checkpoint = True
-                if self.check_origin_checkpoint(resume_from_checkpoint):
-                    use_unified_checkpoint = False
+                use_unified_checkpoint = False
+                if self.is_unified_checkpoint(resume_from_checkpoint):
+                    use_unified_checkpoint = True
+                else:
                     logger.info("Loading origin checkpoint, the next checkpoint will be saved as unified checkpoint")
 
         if use_unified_checkpoint:
@@ -2285,11 +2289,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 checkpoint, OPTIMIZER_NAME, self.model_wrapped
             )
         else:
-            use_unified_checkpoint = False
             if self.args.unified_checkpoint:
-                use_unified_checkpoint = True
-                if self.check_origin_checkpoint(checkpoint):
-                    use_unified_checkpoint = False
+                use_unified_checkpoint = False
+                if self.is_unified_checkpoint(checkpoint):
+                    use_unified_checkpoint = True
+                else:
                     logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")
 
             if not use_unified_checkpoint:
@@ -2940,20 +2944,22 @@ def print_config(self, args=None, key=""):
 
         logger.info("")
 
-    def check_origin_checkpoint(self, resume_from_checkpoint):
-        is_origin_checkpoint_type = False
-
-        weight_name = PADDLE_WEIGHTS_NAME
-        weight_index_name = PADDLE_WEIGHTS_INDEX_NAME
-        weights_file = os.path.join(
-            resume_from_checkpoint,
-            _add_variant(weight_name, self.args.weight_name_suffix),
+    def is_unified_checkpoint(self, resume_from_checkpoint, safe_serialization=True):
+        is_unified_checkpoint_type = False
+        weights_index_name = PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
+        master_weights_index_name = (
+            PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
         )
         weights_index_file = os.path.join(
            resume_from_checkpoint,
-            _add_variant(weight_index_name, self.args.weight_name_suffix),
+            weights_index_name,
+        )
+        master_weights_index_file = os.path.join(
+            resume_from_checkpoint,
+            master_weights_index_name,
         )
-        if distributed_isfile(weights_file) or distributed_isfile(weights_index_file):
-            is_origin_checkpoint_type = True
 
-        return is_origin_checkpoint_type
+        if distributed_isfile(weights_index_file) or distributed_isfile(master_weights_index_file):
+            is_unified_checkpoint_type = True
+
+        return is_unified_checkpoint_type
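The renamed check flips the default: a checkpoint directory is treated as an origin (non-unified) checkpoint unless a weights index or master-weights index file is present. The standalone sketch below mirrors that logic with os.path.isfile as a single-process stand-in for distributed_isfile; the index-file names are illustrative stand-ins, not the actual values of the env constants from paddlenlp.utils.env.

import os

# Illustrative stand-ins for SAFE_WEIGHTS_INDEX_NAME and
# SAFE_MASTER_WEIGHTS_INDEX_NAME (assumed values, for demonstration only).
WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"


def looks_like_unified_checkpoint(checkpoint_dir):
    """A directory counts as a unified checkpoint if either the weights index
    or the master-weights index file exists inside it."""
    weights_index_file = os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)
    master_weights_index_file = os.path.join(checkpoint_dir, MASTER_WEIGHTS_INDEX_NAME)
    return os.path.isfile(weights_index_file) or os.path.isfile(master_weights_index_file)


# Callers now assume an origin checkpoint and only take the unified-checkpoint
# load path when the index files are found.
use_unified_checkpoint = looks_like_unified_checkpoint("./checkpoint-500")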

paddlenlp/transformers/model_utils.py

Lines changed: 9 additions & 2 deletions
@@ -735,15 +735,22 @@ def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
     def is_0d_or_1d(tensor):
         return len(tensor.shape) == 0 or list(tensor.shape) == [1]
 
+    expected_place = paddle.framework._current_expected_place()
     for key, value in model_to_load.state_dict().items():
         if key in state_dict:
             if isinstance(state_dict[key], np.ndarray):
                 raise ValueError(
                     "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
                 )
+            # confirm parameter cast is executed on the same device as model
+            # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
             if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
-                state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
-
+                value_pop = state_dict.pop(key)
+                value_new_place = (
+                    value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
+                )
+                state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
+                del value_new_place
             # unified 0d and 1d tensor
             if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                 if list(value.shape) != list(state_dict[key].shape):
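The updated cast path stages each tensor on the model's expected place before casting and copies the result back afterwards, because an FP32 -> FP16 cast can produce slightly different values on different devices. A standalone sketch of the same pattern, assuming a Paddle build where the private Tensor._copy_to(place, blocking) API used by the patch is available:

import paddle

# Cast on the model's expected place, then restore the tensor's original place.
expected_place = paddle.framework._current_expected_place()

src = paddle.rand([2, 3], dtype="float32")
orig_place = src.place

# Stage the tensor on the expected place only if it is not already there.
staged = src if src.place == expected_place else src._copy_to(expected_place, False)

# Cast where the model lives, then copy the result back to the source's place.
result = paddle.cast(staged, "float16")._copy_to(orig_place, False)
print(result.dtype, result.place)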
