6 changes: 0 additions & 6 deletions examples/auto_parallel/README.md
@@ -28,9 +28,3 @@ should be replaced according to the real environment.


The toolkit provides an auto-parallel solution for ERNIE-4.5 pre-training, including the hybrid parallelism training strategy. More advanced optimizations are on the way.


Currently, the auto-parallel intermediate API has some limitations that are still under development:

- Limited support for MoE
- Limited support for VPP in pipeline parallelism (default USE_VPP=0 in the scripts; when USE_VPP=1, the basic API is used for modeling; see the sketch below)
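
For illustration only, a minimal Python sketch of how a launch script or entry point might consume the `USE_VPP` switch described above; the `build_model` helper is hypothetical and not part of the toolkit.

```python
import os

def build_model(use_intermediate_api: bool) -> str:
    """Hypothetical stand-in for the toolkit's real model construction."""
    return "intermediate-API modeling" if use_intermediate_api else "basic-API modeling"

# USE_VPP defaults to 0 in the scripts; setting USE_VPP=1 switches modeling to the basic API.
use_vpp = int(os.getenv("USE_VPP", "0"))
print(build_model(use_intermediate_api=(use_vpp == 0)))
```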
4 changes: 0 additions & 4 deletions examples/auto_parallel/README_zh.md
@@ -26,7 +26,3 @@
- Note that you need to replace `master_ip` and `port` in `train_4p5_300B_A47B.sh` according to your environment.

This toolkit provides an auto-parallel approach to ERNIE-4.5 pre-training, including the multi-dimensional hybrid parallelism training strategy; further optimizations and features will continue to be added on top of this version.

Currently, the auto-parallel intermediate API has some limitations that are still under development:
- Limited support for MoE
- Limited support for the VPP optimization in pipeline parallelism (default USE_VPP=0 in the scripts; when USE_VPP=1, the basic API is used for modeling)
16 changes: 0 additions & 16 deletions examples/auto_parallel/models/__init__.py

This file was deleted.

265 changes: 9 additions & 256 deletions examples/auto_parallel/models/modeling.py
@@ -16,7 +16,6 @@
import math
import logging
from typing import Optional, Tuple
import contextlib


from copy import deepcopy
@@ -29,8 +28,6 @@
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker

from models.top2_gate import TopKGateFused

from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
PrepareLayerInput,
)
@@ -42,16 +39,18 @@
from paddleformers.transformers.model_utils import PretrainedModel

from models.moe_layer import (
get_gate,
MOELayer,
MoEStatics,
ErnieMLP,
ErnieMoeMLP,
ErnieMoeMLPFused,
TopKGateFused,
)
from models.configuration import ErnieMoEConfig
from utils.training_utils import get_mesh


from paddle.nn.functional.flash_attention import flash_attention
from paddle.incubate.nn.functional import fused_rotary_position_embedding as fused_rope
from paddle.incubate.nn.functional import swiglu


@dataclass
@@ -178,15 +177,7 @@ def scaled_dot_product_attention(
attn_weights = F.softmax_(attn_weights, axis=-1).astype(query_states.dtype)

if config.attention_probs_dropout_prob > 0.0:
if config.tensor_parallel_degree > 1:
with get_rng_state_tracker().rng_state("local_seed"):
attn_weights = F.dropout(
attn_weights,
config.attention_probs_dropout_prob,
training=training,
mode="upscale_in_train",
)
else:
with get_rng_state_tracker().rng_state("local_seed"):
attn_weights = F.dropout(
attn_weights,
config.attention_probs_dropout_prob,
@@ -241,74 +232,6 @@ def _expand_mask(mask, dtype, tgt_length):
)


def get_gate(
config: ErnieMoEConfig,
expert: Tuple[Tuple[int, nn.Layer]],
layer_idx: int,
ipp: int = 0,
) -> Tuple[nn.Layer, nn.LayerList]:
moe_num_experts = config.moe_num_experts
assert (
moe_num_experts >= config.moe_world_size
), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}"
assert (
moe_num_experts % config.moe_world_size == 0
), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0"
moe_num_experts_per_device = moe_num_experts // config.moe_world_size
experts = nn.LayerList([])
for expert_id, (experts_num, fc) in enumerate(expert):
assert experts_num % config.moe_world_size == 0
experts_to_append = []
if not hasattr(fc, "__len__"):
experts_to_append.append(fc)
if expert_id == 1:
with paddle.utils.unique_name.guard("_mm_deepcopy"):
for _ in range(experts_num - 1):
experts_to_append.append(deepcopy(fc))
else:
for _ in range(experts_num - 1):
experts_to_append.append(deepcopy(fc))
else:
experts_to_append = fc
for ex in experts_to_append:
for p in ex.parameters():
p.expert_type = f"expert_type_{expert_id}"
experts.extend(experts_to_append)

logger.info(
f"using moe-world-size: {config.moe_world_size} "
f"expert-per-device: {moe_num_experts_per_device} "
)
if config.moe_use_hard_gate and moe_num_experts <= 2:
gate = None
logger.info("MOE-GATE:-hard-gate")
else:
logger.info(f"MOE-GATE:-{config.moe_gate}")
gate = TopKGateFused(
config, layer_idx=layer_idx, group=config.moe_group, ipp=ipp
)

lm_gate, lm_experts = None, None
logger.info(f"LM-experts-{lm_experts} -- experts-{experts}")

index = 0 if config.moe_group == "dp" else 1
ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(ipp), index)

for i, expert in enumerate(experts):
ep_group_id = i // moe_num_experts_per_device
if isinstance(expert, (ErnieMoeMLPFused, ErnieMoeMLP)):
experts[i].redistribute_expert(
ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()]
)
experts[i].ep_group_id = ep_group_id

if config.moe_use_aux_free:
moe_statics = MoEStatics(config, layer_idx)
else:
moe_statics = None
return gate, experts, lm_gate, lm_experts, moe_statics
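
As an aside on the helper above (which this PR moves into `models/moe_layer.py`), here is a minimal plain-Python sketch of the expert placement rule it implements: expert `i` is assigned to expert-parallel sub-mesh `i // moe_num_experts_per_device`. The numbers are arbitrary.

```python
# Toy illustration of the expert-to-device mapping used in get_gate (no Paddle needed).
moe_num_experts, moe_world_size = 8, 4
assert moe_num_experts % moe_world_size == 0
moe_num_experts_per_device = moe_num_experts // moe_world_size

ep_group_id = {i: i // moe_num_experts_per_device for i in range(moe_num_experts)}
print(ep_group_id)  # {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3}
```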


class RMSNorm(nn.Layer):
def __init__(self, config, ipp=0):
super().__init__()
@@ -476,36 +399,6 @@ def apply_rotary_single(x, rope_emb):
return x * rope_emb[0] + rotate_half_x * rope_emb[1]


class ErnieMLP(nn.Layer):
def __init__(self, config, ipp=None, do_shard_tensor=True):
super().__init__()
self.config = config
self.ipp = ipp
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size

self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias_attr=config.use_bias
)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias_attr=config.use_bias
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias_attr=config.use_bias
)

self.fuse_swiglu = config.fuse_swiglu

def forward(self, x):
if self.fuse_swiglu:
x = swiglu(self.gate_proj(x), self.up_proj(x))
else:
x = F.silu(self.gate_proj(x)) * self.up_proj(x)

out = self.down_proj(x)
return out
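
For context on the class above (also moved into `models/moe_layer.py`), a self-contained toy sketch of the unfused SwiGLU path in `ErnieMLP.forward`; the sizes are made up, and the fused branch would call `swiglu` instead.

```python
import paddle
import paddle.nn.functional as F

# Toy sizes; the real model uses config.hidden_size / config.intermediate_size.
hidden_size, intermediate_size = 8, 16
x = paddle.randn([2, 4, hidden_size])

gate_proj = paddle.nn.Linear(hidden_size, intermediate_size, bias_attr=False)
up_proj = paddle.nn.Linear(hidden_size, intermediate_size, bias_attr=False)
down_proj = paddle.nn.Linear(intermediate_size, hidden_size, bias_attr=False)

# silu(gate) * up, then project back down -- the unfused branch of ErnieMLP.forward.
out = down_proj(F.silu(gate_proj(x)) * up_proj(x))
print(out.shape)  # [2, 4, 8]
```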


class ErnieAttention(nn.Layer):
def __init__(self, config, ipp: Optional[int] = None):
super().__init__()
@@ -696,119 +589,6 @@ def rope_attn(
return attn_output, attn_weights, past_key_value


class ErnieMoeMLP(ErnieMLP):
"""_summary_

Args:
ErnieMoeMLP (_type_): _description_
"""

def __init__(self, config, ipp=0):
"""
doc
"""
disable_ffn_model_parallel = getattr(
config, "disable_ffn_model_parallel", False
)
if disable_ffn_model_parallel:
config = deepcopy(config)
config.tensor_parallel_degree = 1
config.sequence_parallel = False

super().__init__(config, ipp, do_shard_tensor=not disable_ffn_model_parallel)
self.moe_dropout_prob = config.moe_dropout_prob
self.fuse_swiglu = config.fuse_swiglu

def redistribute_expert(self, mesh, placements):
"""
Place the experts on different devices.
"""
self.gate_proj.weight = dist.shard_tensor(
self.gate_proj.weight, mesh, placements
)
self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements)
self.down_proj.weight = dist.shard_tensor(
self.down_proj.weight, mesh, placements
)
if self.config.use_bias:
self.gate_proj.bias = dist.shard_tensor(
self.gate_proj.bias, mesh, placements
)
self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements)
self.down_proj.bias = dist.shard_tensor(
self.down_proj.bias, mesh, placements
)

def forward(self, x):
if self.fuse_swiglu:
x = swiglu(self.gate_proj(x), self.up_proj(x))
else:
x = F.silu(self.gate_proj(x)) * self.up_proj(x)
if self.moe_dropout_prob > 0:
with get_rng_state_tracker().rng_state("local_seed"):
x = F.dropout(x=x, p=self.moe_dropout_prob)
ret = self.down_proj(x)
return ret


class BMMLinear(nn.Layer):
def __init__(self, experts, d_in, d_out, use_bias=False):
super().__init__()
self.weight = self.create_parameter(
[experts, d_in, d_out], dtype=paddle.get_default_dtype()
)
if use_bias:
self.bias = self.create_parameter(
[experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True
)
else:
self.bias = None

def forward(self, x):
"""x: [num_experts, Seq, dim]"""
if self.bias is not None:
return paddle.bmm(x, self.weight) + self.bias
return paddle.bmm(x, self.weight)
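
A quick shape check for the batched per-expert matmul that `BMMLinear` performs (bias omitted); the sizes are arbitrary.

```python
import paddle

# x: [num_experts, seq, d_in] @ weight: [num_experts, d_in, d_out] -> [num_experts, seq, d_out]
num_experts, seq, d_in, d_out = 4, 8, 16, 32
x = paddle.randn([num_experts, seq, d_in])
weight = paddle.randn([num_experts, d_in, d_out])
print(paddle.bmm(x, weight).shape)  # [4, 8, 32]
```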


class ErnieMoeMLPFused(nn.Layer):
def __init__(self, config):
assert (
hasattr(config, "disable_ffn_model_parallel")
or config.tensor_parallel_degree == 1
), f"fused mlp only suport mp-moe, mp={config.tensor_parallel_degree}"
assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn"
super().__init__()
self.moe_dropout_prob = config.moe_dropout_prob
self.num_local_experts = config.moe_num_experts // config.moe_world_size
logger.info(
f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}"
)

self.up_gate_proj = BMMLinear(
self.num_local_experts, config.hidden_size, config.intermediate_size * 2
)
self.down_proj = BMMLinear(
self.num_local_experts, config.intermediate_size, config.hidden_size
)
self.fuse_swiglu = config.fuse_swiglu

def __len__(self):
return self.num_local_experts

def __iter__(self):
return (self for _ in range(1))

def forward(self, x):
if self.fuse_swiglu:
x = swiglu(self.up_gate_proj(x))
else:
gate, x = self.up_gate_proj(x).chunk(2, axis=-1)
x = F.silu(gate) * x
x = self.down_proj(x)
return x
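
And a toy sketch of the unfused branch of `ErnieMoeMLPFused.forward`: a single projection emits the concatenated `[gate, up]` halves, which are split with `chunk(2)` and combined as `silu(gate) * up`. Sizes are illustrative only.

```python
import paddle
import paddle.nn.functional as F

seq, hidden_size, intermediate_size = 4, 8, 16
x = paddle.randn([seq, hidden_size])

# Single projection emitting both halves, as in up_gate_proj (bias omitted).
up_gate_proj = paddle.nn.Linear(hidden_size, intermediate_size * 2, bias_attr=False)
gate, up = up_gate_proj(x).chunk(2, axis=-1)
out = F.silu(gate) * up
print(out.shape)  # [4, 16]
```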


class ErnieDecoderLayer(nn.Layer):
"""
ErnieDecoderLayer is a decoder layer in Ernie model.
@@ -990,16 +770,7 @@ def forward(
)
)

if (
self.config.tensor_parallel_degree > 1
and self.config.hidden_dropout_prob > 0.0
):
current_seed = (
"local_seed" if self.config.sequence_parallel else "global_seed"
)
with get_rng_state_tracker().rng_state(current_seed):
hidden_states = self.residual_add1(hidden_states, residual)
else:
with get_rng_state_tracker().rng_state("local_seed"):
hidden_states = self.residual_add1(hidden_states, residual)

residual = hidden_states
@@ -1017,16 +788,7 @@ def forward(
hidden_states = self.mlp(hidden_states)
gate_logits = None

if (
self.config.tensor_parallel_degree > 1
and self.config.hidden_dropout_prob > 0.0
):
current_seed = (
"local_seed" if self.config.sequence_parallel else "global_seed"
)
with get_rng_state_tracker().rng_state(current_seed):
hidden_states = self.residual_add2(hidden_states, residual)
else:
with get_rng_state_tracker().rng_state("local_seed"):
hidden_states = self.residual_add2(hidden_states, residual)

outputs = (hidden_states,)
@@ -1069,10 +831,7 @@ class ErniePretrainedModel(PretrainedModel):

def init_weights(self, layer):
"""Initialization hook"""
if self.config.tensor_parallel_degree > 1:
rng_tracker = get_rng_state_tracker().rng_state
else:
rng_tracker = contextlib.nullcontext
rng_tracker = get_rng_state_tracker().rng_state

if isinstance(
layer,
@@ -1893,14 +1652,8 @@ def forward(
def auto_dist_config(self, prefix=""):
if prefix != "":
assert prefix.endswith(".")
# if self.config.pipeline_parallel_degree <= 1:
# print(f"ernie use_intermediate_api:{self.config.use_intermediate_api}")
# print(f"ernie pp mode:{self.config.pipeline_schedule_mode}")
ernie_prefix = prefix + "ernie."
layers_prefix = ""
# else:
# ernie_prefix = prefix
# layers_prefix="layers.*."
config = {
"sp_config": {
"parallelize_plan": {