@@ -171,6 +171,14 @@ def get_gpt_layer_with_inference_spec(
171171 )
172172
173173
import os

# NOTE: Environment-driven experiment toggles; slated to become explicit
# config options. Each flag is on only when its variable is exactly "1".
# Defaults: input-layernorm fusion on, QKV down-proj fusion off.
FUSE_INPUT_LAYERNORM = os.getenv("FUSE_INPUT_LAYERNORM", "1") == "1"
FUSE_QKV_DOWN_PROJ = os.getenv("FUSE_QKV_DOWN_PROJ", "0") == "1"


174182def get_gpt_layer_with_transformer_engine_spec (
175183 num_experts : Optional [int ] = None ,
176184 moe_grouped_gemm : Optional [bool ] = False ,
@@ -235,6 +243,31 @@ def get_gpt_layer_with_transformer_engine_spec(
235243
236244 if multi_latent_attention :
237245 assert qk_l2_norm is False , "qk_l2_norm is not supported with MLA."
246+
247+ fuse_input_layernorm = (
248+ FUSE_INPUT_LAYERNORM
249+ and FUSE_QKV_DOWN_PROJ
250+ and backend .column_parallel_layer_norm_linear () is not None
251+ )
252+
253+ input_layernorm = (
254+ IdentityOp
255+ if fuse_input_layernorm
256+ else backend .layer_norm ()
257+ )
258+
259+ linear_q_down_proj = (
260+ backend .column_parallel_layer_norm_linear ()
261+ if fuse_input_layernorm
262+ else backend .linear ()
263+ )
264+
265+ linear_kv_down_proj = (
266+ backend .column_parallel_layer_norm_linear ()
267+ if fuse_input_layernorm
268+ else backend .linear ()
269+ )
270+
238271 linear_q_up_proj = (
239272 backend .column_parallel_layer_norm_linear ()
240273 if qk_layernorm
@@ -248,15 +281,15 @@ def get_gpt_layer_with_transformer_engine_spec(
248281 return ModuleSpec (
249282 module = TransformerLayer ,
250283 submodules = TransformerLayerSubmodules (
251- input_layernorm = backend . layer_norm () ,
284+ input_layernorm = input_layernorm ,
252285 self_attention = ModuleSpec (
253286 module = MLASelfAttention ,
254287 params = {"attn_mask_type" : AttnMaskType .causal },
255288 submodules = MLASelfAttentionSubmodules (
256289 linear_q_proj = backend .column_parallel_linear (),
257- linear_q_down_proj = backend . linear () ,
290+ linear_q_down_proj = linear_q_down_proj ,
258291 linear_q_up_proj = linear_q_up_proj ,
259- linear_kv_down_proj = backend . linear () ,
292+ linear_kv_down_proj = linear_kv_down_proj ,
260293 linear_kv_up_proj = linear_kv_up_proj ,
261294 core_attention = backend .core_attention (),
262295 linear_proj = backend .row_parallel_linear (),
@@ -268,6 +301,13 @@ def get_gpt_layer_with_transformer_engine_spec(
268301 pre_mlp_layernorm = backend .layer_norm () if num_experts else IdentityOp ,
269302 mlp = mlp ,
270303 mlp_bda = get_bias_dropout_add ,
304+ sharded_state_dict_keys_map = (
305+ {
306+ "self_attention.linear_qkv_down_proj.layer_norm_" : "input_layernorm." ,
307+ }
308+ if fuse_input_layernorm
309+ else {}
310+ ),
271311 ),
272312 )
273313 else :
0 commit comments