@@ -171,6 +171,12 @@ def get_gpt_layer_with_inference_spec(
171171 )
172172
173173
import os


def _env_flag(name, default):
    """Return True when the env var *name* (defaulting to *default*) is "1"."""
    return os.environ.get(name, default) == "1"


# NOTE: These toggles are intentionally environment-driven for experimentation.
# They are expected to be replaced by explicit config options.
FUSE_INPUT_LAYERNORM = _env_flag("FUSE_INPUT_LAYERNORM", "1")
FUSE_QKV_DOWN_PROJ = _env_flag("FUSE_QKV_DOWN_PROJ", "0")

174180def get_gpt_layer_with_transformer_engine_spec (
175181 num_experts : Optional [int ] = None ,
176182 moe_grouped_gemm : Optional [bool ] = False ,
@@ -235,6 +241,31 @@ def get_gpt_layer_with_transformer_engine_spec(
235241
236242 if multi_latent_attention :
237243 assert qk_l2_norm is False , "qk_l2_norm is not supported with MLA."
244+
245+ fuse_input_layernorm = (
246+ FUSE_INPUT_LAYERNORM
247+ and FUSE_QKV_DOWN_PROJ
248+ and backend .column_parallel_layer_norm_linear () is not None
249+ )
250+
251+ input_layernorm = (
252+ IdentityOp
253+ if fuse_input_layernorm
254+ else backend .layer_norm ()
255+ )
256+
257+ linear_q_down_proj = (
258+ backend .column_parallel_layer_norm_linear ()
259+ if fuse_input_layernorm
260+ else backend .linear ()
261+ )
262+
263+ linear_kv_down_proj = (
264+ backend .column_parallel_layer_norm_linear ()
265+ if fuse_input_layernorm
266+ else backend .linear ()
267+ )
268+
238269 linear_q_up_proj = (
239270 backend .column_parallel_layer_norm_linear ()
240271 if qk_layernorm
@@ -248,15 +279,15 @@ def get_gpt_layer_with_transformer_engine_spec(
248279 return ModuleSpec (
249280 module = TransformerLayer ,
250281 submodules = TransformerLayerSubmodules (
251- input_layernorm = backend . layer_norm () ,
282+ input_layernorm = input_layernorm ,
252283 self_attention = ModuleSpec (
253284 module = MLASelfAttention ,
254285 params = {"attn_mask_type" : AttnMaskType .causal },
255286 submodules = MLASelfAttentionSubmodules (
256287 linear_q_proj = backend .column_parallel_linear (),
257- linear_q_down_proj = backend . linear () ,
288+ linear_q_down_proj = linear_q_down_proj ,
258289 linear_q_up_proj = linear_q_up_proj ,
259- linear_kv_down_proj = backend . linear () ,
290+ linear_kv_down_proj = linear_kv_down_proj ,
260291 linear_kv_up_proj = linear_kv_up_proj ,
261292 core_attention = backend .core_attention (),
262293 linear_proj = backend .row_parallel_linear (),
@@ -268,6 +299,13 @@ def get_gpt_layer_with_transformer_engine_spec(
268299 pre_mlp_layernorm = backend .layer_norm () if num_experts else IdentityOp ,
269300 mlp = mlp ,
270301 mlp_bda = get_bias_dropout_add ,
302+ sharded_state_dict_keys_map = (
303+ {
304+ "self_attention.linear_qkv_down_proj.layer_norm_" : "input_layernorm." ,
305+ }
306+ if fuse_input_layernorm
307+ else {}
308+ ),
271309 ),
272310 )
273311 else :
0 commit comments