1414from megatron .core .transformer .identity_op import IdentityOp
1515from megatron .core .transformer .mlp import MLP , MLPSubmodules
1616from megatron .core .transformer .multi_latent_attention import (
17+ FusedMLASelfAttention ,
1718 MLASelfAttention ,
1819 MLASelfAttentionSubmodules ,
1920)
@@ -184,6 +185,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
184185 use_te_activation_func : bool = False ,
185186 use_kitchen_attention : bool = False ,
186187 kitchen_attention_backend : str = "sdpa" ,
188+ mla_down_proj_fusion : bool = False ,
187189) -> TransformerLayerSubmodules :
188190 """Use these submodules to use lower-level Transformer Engine modules (required for fp8
189191 training).
@@ -198,6 +200,9 @@ def get_gpt_layer_with_transformer_engine_submodules(
198200 qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
199201 use_te_op_fuser (bool, optional): Use Transformer Engine's operation-based API, which may
200202 enable certain operation fusions. Defaults to False.
203+ mla_down_proj_fusion (bool, optional): Enable fused q/kv down-projection and fused input
204+ layernorm when backend supports. Otherwise fall back
205+ to the unfused MLA.
201206
202207 Returns:
203208 TransformerLayerSubmodules: TE modules to construct a TransformerLayer
@@ -243,6 +248,45 @@ def get_gpt_layer_with_transformer_engine_submodules(
243248 if qk_layernorm
244249 else backend .column_parallel_linear ()
245250 )
251+
252+ if mla_down_proj_fusion :
253+ fuse_input_layernorm = backend .column_parallel_layer_norm_linear () is not None
254+ input_layernorm = IdentityOp if fuse_input_layernorm else backend .layer_norm ()
255+ down_proj_linear = (
256+ backend .column_parallel_layer_norm_linear ()
257+ if fuse_input_layernorm
258+ else backend .linear ()
259+ )
260+ return TransformerLayerSubmodules (
261+ input_layernorm = input_layernorm ,
262+ self_attention = ModuleSpec (
263+ module = FusedMLASelfAttention ,
264+ params = {"attn_mask_type" : AttnMaskType .causal },
265+ submodules = MLASelfAttentionSubmodules (
266+ linear_q_proj = backend .column_parallel_linear (),
267+ linear_qkv_down_proj = down_proj_linear ,
268+ linear_q_up_proj = linear_q_up_proj ,
269+ linear_kv_up_proj = linear_kv_up_proj ,
270+ core_attention = backend .core_attention (),
271+ linear_proj = backend .row_parallel_linear (),
272+ q_layernorm = IdentityOp ,
273+ kv_layernorm = IdentityOp ,
274+ ),
275+ ),
276+ self_attn_bda = get_bias_dropout_add ,
277+ pre_mlp_layernorm = backend .layer_norm () if num_experts else IdentityOp ,
278+ mlp = mlp ,
279+ mlp_bda = get_bias_dropout_add ,
280+ sharded_state_dict_keys_map = (
281+ {
282+ "self_attention.linear_q_down_proj.layer_norm_" : "input_layernorm." ,
283+ "self_attention.linear_kv_down_proj.layer_norm_" : "input_layernorm." ,
284+ "self_attention.linear_qkv_down_proj.layer_norm_" : "input_layernorm." ,
285+ }
286+ if fuse_input_layernorm
287+ else {}
288+ ),
289+ )
246290 return TransformerLayerSubmodules (
247291 input_layernorm = backend .layer_norm (has_residual = True ),
248292 self_attention = ModuleSpec (
@@ -526,6 +570,7 @@ def get_gpt_decoder_layer_specs(
526570 use_te_activation_func = config .use_te_activation_func ,
527571 use_kitchen_attention = config .use_kitchen_attention ,
528572 kitchen_attention_backend = config .kitchen_attention_backend ,
573+ mla_down_proj_fusion = getattr (config , "mla_down_proj_fusion" , False ),
529574 )
530575 moe_layer_spec = get_gpt_layer_with_transformer_engine_spec (
531576 num_experts = config .num_moe_experts ,
@@ -537,6 +582,7 @@ def get_gpt_decoder_layer_specs(
537582 use_te_activation_func = config .use_te_activation_func ,
538583 use_kitchen_attention = config .use_kitchen_attention ,
539584 kitchen_attention_backend = config .kitchen_attention_backend ,
585+ mla_down_proj_fusion = getattr (config , "mla_down_proj_fusion" , False ),
540586 )
541587 elif config .transformer_impl == "inference_optimized" :
542588 layer_norm_impl = TENorm
0 commit comments