diff --git a/paddlenlp/transformers/llama/modeling_auto_pp.py b/paddlenlp/transformers/llama/modeling_auto_pp.py
index bc4fed315131..befdb5442eda 100644
--- a/paddlenlp/transformers/llama/modeling_auto_pp.py
+++ b/paddlenlp/transformers/llama/modeling_auto_pp.py
@@ -188,6 +188,11 @@ def manual_model_split(model, stage_idx, group, mode, pp_degree):
 
     layer_lists = model.layers
 
+    # shared_params_names currently only supports ERNIE
+    shared_params_names = [["embedding_0.w_0.dist", "ernie_lm_head_0.w_0.dist"]]
+
+    shared_mp = build_shared_param_map(model, shared_params_names)
+
     def _build_stage(model, stage_idx, group):
         new_model = None
         if stage_idx == 0:  # the first model_chunk needs special input handling
@@ -200,7 +205,7 @@ def _build_stage(model, stage_idx, group):
             new_model = LlamaChunk(
                 layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=False
             )
-        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group)
+        stage = PipelineStage(new_model, stage_idx, chunk_num, group=group, shared_parameters=shared_mp)
         return stage
 
     stages = []
@@ -210,6 +215,25 @@ def _build_stage(model, stage_idx, group):
     return stages
 
 
+def build_shared_param_map(model, shared_params_names):
+    shared_mp = []
+    for pair in shared_params_names:
+        assert len(pair) == 2, "Only exactly two parameters are supported for sharing."
+        ori_name = pair[0]
+        sync_name = pair[1]
+        ori_param = get_param_from_name(ori_name, model)
+        sync_param = get_param_from_name(sync_name, model)
+        shared_mp.append({"params": [ori_param, sync_param]})
+    return shared_mp
+
+
+def get_param_from_name(param_name, model):
+    for param in model.parameters():
+        if param.name == param_name:
+            return param
+    raise ValueError(f"{param_name} not found in model parameters")
+
+
 def get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
     assert mode in ["VPP", "1F1B", "FThenB"]
     stages = manual_model_split(model, group.rank, group, mode, pp_degree)
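
For context, a minimal standalone sketch of how the new helpers resolve a shared-parameter name pair into the mapping passed to PipelineStage. The _FakeParam/_FakeModel stand-ins below are hypothetical and only mimic the .name / .parameters() interface the helpers rely on; they are not part of this patch and no Paddle runtime is required to run the sketch.

    # Hypothetical stand-ins that only expose .name and .parameters(),
    # which is all build_shared_param_map / get_param_from_name use.
    class _FakeParam:
        def __init__(self, name):
            self.name = name


    class _FakeModel:
        def __init__(self, names):
            self._params = [_FakeParam(n) for n in names]

        def parameters(self):
            return self._params


    def get_param_from_name(param_name, model):
        for param in model.parameters():
            if param.name == param_name:
                return param
        raise ValueError(f"{param_name} not found in model parameters")


    def build_shared_param_map(model, shared_params_names):
        shared_mp = []
        for pair in shared_params_names:
            assert len(pair) == 2, "Only exactly two parameters are supported for sharing."
            shared_mp.append({"params": [get_param_from_name(pair[0], model), get_param_from_name(pair[1], model)]})
        return shared_mp


    # Resolve the embedding / lm_head pair into a shared-parameter map entry.
    model = _FakeModel(["embedding_0.w_0.dist", "ernie_lm_head_0.w_0.dist", "other_layer_0.w_0"])
    shared_mp = build_shared_param_map(model, [["embedding_0.w_0.dist", "ernie_lm_head_0.w_0.dist"]])
    print([p.name for p in shared_mp[0]["params"]])
    # ['embedding_0.w_0.dist', 'ernie_lm_head_0.w_0.dist']

The resulting list of {"params": [ori_param, sync_param]} dicts is what each stage receives via the new shared_parameters argument, so the pipeline runtime can keep the two parameters in sync across stages.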