37 changes: 18 additions & 19 deletions gpt_builders.py
@@ -115,43 +115,42 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
 
 def _get_transformer_layer_spec(use_te, config):
     """Get transformer layer specification based on configuration.
 
     Args:
         use_te (bool): Whether to use Transformer Engine
-        args: Training arguments
         config: Model configuration
 
     Returns:
         transformer_layer_spec: The transformer layer specification
     """
-    args = get_args()
     if use_te:
         return get_gpt_layer_with_transformer_engine_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.experimental_attention_variant,
-            qk_l2_norm=args.qk_l2_norm,
+            config.num_moe_experts,
+            config.moe_grouped_gemm,
+            config.qk_layernorm,
+            config.multi_latent_attention,
+            config.experimental_attention_variant,
+            qk_l2_norm=config.qk_l2_norm,
             use_kitchen=config.use_kitchen,
             use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
+            mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
         )
     elif config.transformer_impl == "inference_optimized":
         return get_gpt_layer_with_inference_spec(
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            qk_l2_norm=args.qk_l2_norm,
+            config.qk_layernorm,
+            config.multi_latent_attention,
+            qk_l2_norm=config.qk_l2_norm,
         )
     else:
         return get_gpt_layer_local_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.experimental_attention_variant,
-            normalization=args.normalization,
+            config.num_moe_experts,
+            config.moe_grouped_gemm,
+            config.qk_layernorm,
+            config.multi_latent_attention,
+            config.experimental_attention_variant,
+            normalization=config.normalization,
             use_kitchen=config.use_kitchen,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
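Worth noting in the hunk above: the new flag is read with getattr(config, "mla_down_proj_fusion", False) rather than a plain attribute access, so config objects created before this change keep working. A minimal sketch of the pattern, using hypothetical SimpleNamespace configs in place of a real TransformerConfig:

from types import SimpleNamespace

# Hypothetical configs: one predating the flag, one that sets it explicitly.
old_config = SimpleNamespace(multi_latent_attention=True)
new_config = SimpleNamespace(multi_latent_attention=True, mla_down_proj_fusion=True)

# A missing attribute falls back to False instead of raising AttributeError,
# so older configs simply get the unfused MLA path.
assert getattr(old_config, "mla_down_proj_fusion", False) is False
assert getattr(new_config, "mla_down_proj_fusion", False) is True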
@@ -397,6 +397,7 @@ def _get_self_attention_module_spec(
         use_te_activation_func=config.use_te_activation_func,
         use_kitchen_attention=config.use_kitchen_attention,
         kitchen_attention_backend=config.kitchen_attention_backend,
+        mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
     )
     attn_spec = layer_spec.submodules.self_attention
     if config.multi_latent_attention:
46 changes: 46 additions & 0 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -14,6 +14,7 @@
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.multi_latent_attention import (
+    FusedMLASelfAttention,
     MLASelfAttention,
     MLASelfAttentionSubmodules,
 )
@@ -184,6 +185,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
     use_te_activation_func: bool = False,
     use_kitchen_attention: bool = False,
     kitchen_attention_backend: str = "sdpa",
+    mla_down_proj_fusion: bool = False,
 ) -> TransformerLayerSubmodules:
     """Use these submodules to use lower-level Transformer Engine modules (required for fp8
     training).
@@ -198,6 +200,9 @@
         qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
         use_te_op_fuser (bool, optional): Use Transformer Engine's operation-based API, which may
             enable certain operation fusions. Defaults to False.
+        mla_down_proj_fusion (bool, optional): Enable the fused q/kv down-projection, with the
+            input layernorm fused in when the backend supports it; otherwise fall back to the
+            unfused MLA path. Defaults to False.
 
     Returns:
         TransformerLayerSubmodules: TE modules to construct a TransformerLayer
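The fallback described in the new docstring hinges on whether the spec backend can provide a fused layernorm+linear module, as the next hunk shows. A minimal sketch of that capability probe, with hypothetical stand-in backends rather than the real TE/local spec backends:

# Hypothetical stand-ins for a spec backend: returning a class from
# column_parallel_layer_norm_linear() signals that the input layernorm can be
# folded into the down-projection; returning None keeps a standalone norm.
class FusedCapableBackend:
    def column_parallel_layer_norm_linear(self):
        return object  # stands in for a fused layernorm+linear module class

class PlainBackend:
    def column_parallel_layer_norm_linear(self):
        return None

for backend in (FusedCapableBackend(), PlainBackend()):
    fuse_input_layernorm = backend.column_parallel_layer_norm_linear() is not None
    print(type(backend).__name__, "->", "fused" if fuse_input_layernorm else "unfused")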
@@ -243,6 +248,45 @@
         if qk_layernorm
         else backend.column_parallel_linear()
     )
+
+    if mla_down_proj_fusion:
+        fuse_input_layernorm = backend.column_parallel_layer_norm_linear() is not None
+        input_layernorm = IdentityOp if fuse_input_layernorm else backend.layer_norm()
+        down_proj_linear = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+        return TransformerLayerSubmodules(
+            input_layernorm=input_layernorm,
+            self_attention=ModuleSpec(
+                module=FusedMLASelfAttention,
+                params={"attn_mask_type": AttnMaskType.causal},
+                submodules=MLASelfAttentionSubmodules(
+                    linear_q_proj=backend.column_parallel_linear(),
+                    linear_qkv_down_proj=down_proj_linear,
+                    linear_q_up_proj=linear_q_up_proj,
+                    linear_kv_up_proj=linear_kv_up_proj,
+                    core_attention=backend.core_attention(),
+                    linear_proj=backend.row_parallel_linear(),
+                    q_layernorm=IdentityOp,
+                    kv_layernorm=IdentityOp,
+                ),
+            ),
+            self_attn_bda=get_bias_dropout_add,
+            pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
+            mlp=mlp,
+            mlp_bda=get_bias_dropout_add,
+            sharded_state_dict_keys_map=(
+                {
+                    "self_attention.linear_q_down_proj.layer_norm_": "input_layernorm.",
+                    "self_attention.linear_kv_down_proj.layer_norm_": "input_layernorm.",
+                    "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
+                }
+                if fuse_input_layernorm
+                else {}
+            ),
+        )
 
     return TransformerLayerSubmodules(
         input_layernorm=backend.layer_norm(has_residual=True),
         self_attention=ModuleSpec(
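One detail in the block above: when the input layernorm is folded into the down-projection linear, its parameters live under the attention module, and sharded_state_dict_keys_map renames them back to the canonical input_layernorm. prefix so fused and unfused layouts stay checkpoint-compatible. A toy illustration of that renaming (hypothetical key name; the real remapping is done by Megatron's sharded state dict machinery):

keys_map = {
    "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
}

# A parameter stored inside the fused down-projection module...
fused_key = "self_attention.linear_qkv_down_proj.layer_norm_weight"

# ...is renamed to the checkpoint-canonical location.
for prefix, replacement in keys_map.items():
    if fused_key.startswith(prefix):
        fused_key = replacement + fused_key[len(prefix):]
print(fused_key)  # input_layernorm.weight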
@@ -526,6 +570,7 @@ def get_gpt_decoder_layer_specs(
             use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
+            mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
         )
         moe_layer_spec = get_gpt_layer_with_transformer_engine_spec(
             num_experts=config.num_moe_experts,
@@ -537,6 +582,7 @@
             use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
+            mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
         )
     elif config.transformer_impl == "inference_optimized":
         layer_norm_impl = TENorm
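Taken together, a hedged usage sketch of requesting the fused MLA layer spec directly. It assumes Transformer Engine is installed, keyword names are inferred from the call sites in this diff, and every omitted option is left at its default:

from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

# Fusion is only meaningful with MLA enabled; argument names here follow the
# call sites shown in the hunks above and may not be exhaustive.
layer_spec = get_gpt_layer_with_transformer_engine_spec(
    num_experts=None,
    moe_grouped_gemm=False,
    qk_layernorm=True,
    multi_latent_attention=True,
    mla_down_proj_fusion=True,
)

# If the backend lacks a fused layernorm+linear module, the spec silently keeps
# a standalone input layernorm (see fuse_input_layernorm above).
print(layer_spec.submodules.self_attention.module.__name__)
# expected: FusedMLASelfAttention, given the assumptions above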