Commit 8fd390d

cjldyaox12 and Xin Yao authored
Fuse MLA DOWN projection GEMMs (#3039)
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
1 parent f544034 commit 8fd390d

7 files changed: +621 −45 lines
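
For orientation (this note and the sketch below are editorial, not part of the commit): in multi-latent attention (MLA), the hidden state is first down-projected by two small GEMMs, one producing the query latent and one the key/value latent. Both GEMMs read the same activation, so they can be fused into a single GEMM over stacked weights whose output is split afterwards, which is the idea the commit title points at. A minimal PyTorch sketch with made-up sizes:

import torch

# Illustrative sizes and names only; these are stand-ins, not the Megatron implementation.
hidden, q_rank, kv_rank = 1024, 512, 320
x = torch.randn(8, hidden)

w_q_down = torch.randn(q_rank, hidden)    # stand-in for a q down-projection weight
w_kv_down = torch.randn(kv_rank, hidden)  # stand-in for a kv down-projection weight

# Unfused: two GEMMs over the same input.
q_latent = x @ w_q_down.t()
kv_latent = x @ w_kv_down.t()

# Fused: one GEMM over the stacked weights, then a cheap split of the output.
w_qkv_down = torch.cat([w_q_down, w_kv_down], dim=0)
q_fused, kv_fused = torch.split(x @ w_qkv_down.t(), [q_rank, kv_rank], dim=-1)

assert torch.allclose(q_latent, q_fused) and torch.allclose(kv_latent, kv_fused)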

gpt_builders.py

Lines changed: 18 additions & 19 deletions
@@ -115,43 +115,42 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
 
 def _get_transformer_layer_spec(use_te, config):
     """Get transformer layer specification based on configuration.
-
+
     Args:
         use_te (bool): Whether to use Transformer Engine
-        args: Training arguments
         config: Model configuration
-
+
     Returns:
         transformer_layer_spec: The transformer layer specification
     """
-    args = get_args()
     if use_te:
         return get_gpt_layer_with_transformer_engine_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.experimental_attention_variant,
-            qk_l2_norm=args.qk_l2_norm,
+            config.num_moe_experts,
+            config.moe_grouped_gemm,
+            config.qk_layernorm,
+            config.multi_latent_attention,
+            config.experimental_attention_variant,
+            qk_l2_norm=config.qk_l2_norm,
             use_kitchen=config.use_kitchen,
             use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
+            mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
         )
     elif config.transformer_impl == "inference_optimized":
         return get_gpt_layer_with_inference_spec(
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            qk_l2_norm=args.qk_l2_norm,
+            config.qk_layernorm,
+            config.multi_latent_attention,
+            qk_l2_norm=config.qk_l2_norm,
         )
     else:
         return get_gpt_layer_local_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.experimental_attention_variant,
-            normalization=args.normalization,
+            config.num_moe_experts,
+            config.moe_grouped_gemm,
+            config.qk_layernorm,
+            config.multi_latent_attention,
+            config.experimental_attention_variant,
+            normalization=config.normalization,
             use_kitchen=config.use_kitchen,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
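
Editorial note on the builder change above: the layer-spec arguments now come from the model config instead of the global args object, and the new flag is forwarded with a getattr default so that config objects created before the flag existed silently keep the unfused MLA path. A tiny sketch of that fallback using stand-in config objects (not Megatron's TransformerConfig):

from types import SimpleNamespace

old_config = SimpleNamespace(multi_latent_attention=True)  # predates the flag
new_config = SimpleNamespace(multi_latent_attention=True, mla_down_proj_fusion=True)

# Mirrors the call sites above: missing attribute -> False -> unfused MLA.
print(getattr(old_config, "mla_down_proj_fusion", False))  # False
print(getattr(new_config, "mla_down_proj_fusion", False))  # True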

megatron/core/models/gpt/experimental_attention_variant_module_specs.py

Lines changed: 1 addition & 0 deletions
@@ -397,6 +397,7 @@ def _get_self_attention_module_spec(
         use_te_activation_func=config.use_te_activation_func,
         use_kitchen_attention=config.use_kitchen_attention,
         kitchen_attention_backend=config.kitchen_attention_backend,
+        mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
     )
     attn_spec = layer_spec.submodules.self_attention
     if config.multi_latent_attention:

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 46 additions & 0 deletions
@@ -14,6 +14,7 @@
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.multi_latent_attention import (
+    FusedMLASelfAttention,
     MLASelfAttention,
     MLASelfAttentionSubmodules,
 )
@@ -184,6 +185,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
     use_te_activation_func: bool = False,
     use_kitchen_attention: bool = False,
     kitchen_attention_backend: str = "sdpa",
+    mla_down_proj_fusion: bool = False,
 ) -> TransformerLayerSubmodules:
     """Use these submodules to use lower-level Transformer Engine modules (required for fp8
     training).
@@ -198,6 +200,9 @@ def get_gpt_layer_with_transformer_engine_submodules(
         qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
         use_te_op_fuser (bool, optional): Use Transformer Engine's operation-based API, which may
             enable certain operation fusions. Defaults to False.
+        mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input
+            layernorm when backend supports. Otherwise fall back
+            to the unfused MLA.
 
     Returns:
         TransformerLayerSubmodules: TE modules to construct a TransformerLayer
@@ -243,6 +248,45 @@ def get_gpt_layer_with_transformer_engine_submodules(
         if qk_layernorm
         else backend.column_parallel_linear()
     )
+
+    if mla_down_proj_fusion:
+        fuse_input_layernorm = backend.column_parallel_layer_norm_linear() is not None
+        input_layernorm = IdentityOp if fuse_input_layernorm else backend.layer_norm()
+        down_proj_linear = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+        return TransformerLayerSubmodules(
+            input_layernorm=input_layernorm,
+            self_attention=ModuleSpec(
+                module=FusedMLASelfAttention,
+                params={"attn_mask_type": AttnMaskType.causal},
+                submodules=MLASelfAttentionSubmodules(
+                    linear_q_proj=backend.column_parallel_linear(),
+                    linear_qkv_down_proj=down_proj_linear,
+                    linear_q_up_proj=linear_q_up_proj,
+                    linear_kv_up_proj=linear_kv_up_proj,
+                    core_attention=backend.core_attention(),
+                    linear_proj=backend.row_parallel_linear(),
+                    q_layernorm=IdentityOp,
+                    kv_layernorm=IdentityOp,
+                ),
+            ),
+            self_attn_bda=get_bias_dropout_add,
+            pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
+            mlp=mlp,
+            mlp_bda=get_bias_dropout_add,
+            sharded_state_dict_keys_map=(
+                {
+                    "self_attention.linear_q_down_proj.layer_norm_": "input_layernorm.",
+                    "self_attention.linear_kv_down_proj.layer_norm_": "input_layernorm.",
+                    "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
+                }
+                if fuse_input_layernorm
+                else {}
+            ),
+        )
     return TransformerLayerSubmodules(
         input_layernorm=backend.layer_norm(has_residual=True),
         self_attention=ModuleSpec(
@@ -526,6 +570,7 @@ def get_gpt_decoder_layer_specs(
             use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
+            mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
         )
         moe_layer_spec = get_gpt_layer_with_transformer_engine_spec(
             num_experts=config.num_moe_experts,
@@ -537,6 +582,7 @@ def get_gpt_decoder_layer_specs(
             use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
+            mla_down_proj_fusion=getattr(config, "mla_down_proj_fusion", False),
         )
     elif config.transformer_impl == "inference_optimized":
         layer_norm_impl = TENorm
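
Editorial note on the fused branch above: when the spec backend exposes a fused layernorm-plus-linear module (column_parallel_layer_norm_linear), the standalone input_layernorm becomes an IdentityOp and the normalization is absorbed into linear_qkv_down_proj; the added sharded_state_dict_keys_map then appears to remap the norm parameters that now live under the down-projection back to the input_layernorm prefix, keeping the checkpoint layout of the unfused spec. The toy module below only illustrates the functional equivalence this relies on; it composes the two ops rather than fusing kernels the way a real backend module (e.g. Transformer Engine's LayerNormLinear) does:

import torch
import torch.nn as nn

class ToyLayerNormLinear(nn.Module):
    """Composes LayerNorm and Linear; note the norm parameters live inside this module."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.layer_norm_ = nn.LayerNorm(d_in)
        self.linear = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        return self.linear(self.layer_norm_(x))

x = torch.randn(4, 1024)
unfused = nn.Sequential(nn.LayerNorm(1024), nn.Linear(1024, 512, bias=False))
fused = ToyLayerNormLinear(1024, 512)
fused.layer_norm_.load_state_dict(unfused[0].state_dict())
fused.linear.load_state_dict(unfused[1].state_dict())
assert torch.allclose(unfused(x), fused(x))  # same function, with the norm folded into the linear module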

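A hedged usage sketch (editorial, not from the commit; it needs a Megatron-Core build that contains this change, and the keyword names simply mirror the call sites above): build a Transformer Engine layer spec with the flag enabled and inspect which attention module was selected.

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec

spec = get_gpt_layer_with_transformer_engine_spec(
    num_experts=None,
    moe_grouped_gemm=False,
    qk_layernorm=False,
    multi_latent_attention=True,
    mla_down_proj_fusion=True,
)
# With a backend that supports the fusion, the spec should pick FusedMLASelfAttention and an
# IdentityOp input_layernorm (the norm is folded into linear_qkv_down_proj as sketched above).
print(spec.submodules.self_attention.module)
print(spec.submodules.input_layernorm)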