diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index ed8048432243..7c95a9b307fb 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -32,6 +32,7 @@
 import aistudio_sdk
 import numpy as np
 import paddle
+import paddle.distributed as dist
 import paddle.nn as nn
 import six
 from huggingface_hub import (
@@ -2988,14 +2989,40 @@ def _generate_auto_dist_config(self, auto_dist_degree):
             "pp_config": None,
             "cp_config": None,
         }
+        has_auto_dist_config = False
         for name, layer in self.named_sublayers(include_self=True):
             if hasattr(layer, "auto_dist_config"):
+                has_auto_dist_config = True
                 if name != "":
                     prefix = name + "."
                 else:
                     prefix = ""
                 layer_config = layer.auto_dist_config(prefix)
                 merged_config = self.merge_auto_dist_configs([merged_config, layer_config])
+        if not has_auto_dist_config:
+            model_file = inspect.getfile(self.__class__)
+            model_dir = os.path.dirname(model_file)
+            config_path = os.path.join(model_dir, "intermediate_api_config.json")
+            assert os.path.exists(config_path), (
+                f"intermediate api config file not found at {config_path}. "
+                "Please ensure the file exists or implement auto_dist_config in layers."
+            )
+            with open(config_path, "r") as f:
+                iconfig = json.load(f)
+
+            def process_config_value(value):
+                if isinstance(value, dict):
+                    return {k: process_config_value(v) for k, v in value.items()}
+                elif isinstance(value, list):
+                    return [getattr(dist, v)() for v in value if hasattr(dist, v)]
+                elif isinstance(value, str) and hasattr(dist, value):
+                    return getattr(dist, value)()
+                return value
+
+            iconfig = {k: process_config_value(v) for k, v in iconfig.items()}
+            layer_config = {f"{k}": v for k, v in iconfig.items()}
+            merged_config = self.merge_auto_dist_configs([merged_config, layer_config])
+
         final_config = {
             "dp_config": None,
             "mp_config": None,
diff --git a/paddlenlp/transformers/qwen/intermediate_api_config.json b/paddlenlp/transformers/qwen/intermediate_api_config.json
new file mode 100644
index 000000000000..bc1e3e3fede1
--- /dev/null
+++ b/paddlenlp/transformers/qwen/intermediate_api_config.json
@@ -0,0 +1,36 @@
+{
+    "sp_config": {
+        "parallelize_plan": {
+            "qwen.wte": [
+                "RowWiseParallel",
+                "SequenceParallelBegin"
+            ],
+            "qwen.h.*.attn.c_attn": "ColWiseParallel",
+            "qwen.h.*.attn.c_proj": "RowWiseParallel",
+            "qwen.h.*.attn": "SequenceParallelDisable",
+            "qwen.h.*.mlp.gate_up_fused_proj": "ColWiseParallel",
+            "qwen.h.*.mlp.w1": "ColWiseParallel",
+            "qwen.h.*.mlp.w2": "ColWiseParallel",
+            "qwen.h.*.mlp.c_proj": "RowWiseParallel",
+            "qwen.h.*.mlp": "SequenceParallelDisable",
+            "lm_head.weight": "ColWiseParallel",
+            "lm_head": "SequenceParallelEnd"
+        }
+    },
+    "mp_config": {
+        "parallelize_plan": {
+            "qwen.wte": "RowWiseParallel",
+            "qwen.h.*.attn.c_attn": "ColWiseParallel",
+            "qwen.h.*.attn.c_proj": "RowWiseParallel",
+            "qwen.h.*.mlp.gate_up_fused_proj": "ColWiseParallel",
+            "qwen.h.*.mlp.w1": "ColWiseParallel",
+            "qwen.h.*.mlp.w2": "ColWiseParallel",
+            "qwen.h.*.mlp.c_proj": "RowWiseParallel",
+            "lm_head.weight": "ColWiseParallel"
+        }
+    },
+    "pp_config": {
+        "split_spec": "qwen.h",
+        "global_spec": "qwen.global_layer"
+    }
+}
diff --git a/paddlenlp/transformers/qwen/modeling_network.py b/paddlenlp/transformers/qwen/modeling_network.py
index a8eddb3437ea..2e92e1191322 100644
--- a/paddlenlp/transformers/qwen/modeling_network.py
+++ b/paddlenlp/transformers/qwen/modeling_network.py
@@ -17,7 +17,6 @@
 import warnings
 
 import paddle
-import paddle.distributed as dist
 import paddle.nn.functional as F
 from paddle import nn
 from paddle.distributed.fleet.utils import recompute
@@ -660,48 +659,6 @@ def forward(
 
         return lm_logits
 
-    def auto_dist_config(self, prefix=""):
-        if prefix != "":
-            assert prefix.endswith(".")
-        config = {
-            "sp_config": {
-                "parallelize_plan": {
-                    f"{prefix}qwen.wte": [
-                        dist.RowWiseParallel(),
-                        dist.SequenceParallelBegin(),
-                    ],
-                    f"{prefix}qwen.h.*.attn.c_attn": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.attn.c_proj": dist.RowWiseParallel(),
-                    f"{prefix}qwen.h.*.attn": dist.SequenceParallelDisable(),
-                    f"{prefix}qwen.h.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.w1": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.w2": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.c_proj": dist.RowWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
-                    f"{prefix}lm_head.weight": dist.ColWiseParallel(),
-                    f"{prefix}lm_head": dist.SequenceParallelEnd(),
-                }
-            },
-            "mp_config": {
-                "parallelize_plan": {
-                    f"{prefix}qwen.wte": dist.RowWiseParallel(),
-                    f"{prefix}qwen.h.*.attn.c_attn": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.attn.c_proj": dist.RowWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.w1": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.w2": dist.ColWiseParallel(),
-                    f"{prefix}qwen.h.*.mlp.c_proj": dist.RowWiseParallel(),
-                    f"{prefix}lm_head.weight": dist.ColWiseParallel(),
-                }
-            },
-            "pp_config": {
-                "split_spec": f"{prefix}qwen.h",
-                "global_spec": f"{prefix}qwen.global_layer",
-            },
-        }
-
-        return config
-
 
 class GlobalNet(nn.Layer):
     def __init__(self, config) -> None:
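
Reviewer note: the sketch below is not part of the patch; it mirrors the `process_config_value` helper added to `_generate_auto_dist_config` and shows how string entries from `intermediate_api_config.json` are resolved into `paddle.distributed` plan objects before being merged via `merge_auto_dist_configs`. The `demo_config` dict is a hypothetical excerpt of the JSON file, trimmed for illustration.

```python
# Illustrative sketch only: how the JSON fallback maps plan strings to
# paddle.distributed objects (same logic as the helper added in this diff).
import paddle.distributed as dist


def process_config_value(value):
    # Recurse into dicts; map strings (and lists of strings) that name
    # paddle.distributed classes, e.g. "ColWiseParallel", to instances.
    if isinstance(value, dict):
        return {k: process_config_value(v) for k, v in value.items()}
    elif isinstance(value, list):
        return [getattr(dist, v)() for v in value if hasattr(dist, v)]
    elif isinstance(value, str) and hasattr(dist, value):
        return getattr(dist, value)()
    return value


# Hypothetical excerpt of intermediate_api_config.json, for demonstration.
demo_config = {
    "mp_config": {
        "parallelize_plan": {
            "qwen.wte": "RowWiseParallel",
            "qwen.h.*.attn.c_attn": "ColWiseParallel",
        }
    }
}

resolved = {k: process_config_value(v) for k, v in demo_config.items()}
# resolved["mp_config"]["parallelize_plan"]["qwen.wte"] is now a
# dist.RowWiseParallel() instance, ready for merge_auto_dist_configs.
print(resolved)
```

One consequence of the string form: constructor arguments such as the `need_transpose=False` used for `qwen.h.*.mlp` in the removed Python plan have no representation in the JSON file, since each name is instantiated with no arguments.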