@@ -171,6 +171,13 @@ def get_gpt_layer_with_inference_spec(
171171 )
172172
173173
174+ import os
175+ # NOTE: These toggles are intentionally environment-driven for experimentation.
176+ # They are expected to be replaced by explicit config options.
177+ FUSE_INPUT_LAYERNORM = os.environ.get("FUSE_INPUT_LAYERNORM", "1") == "1"
178+ FUSE_QKV_DOWN_PROJ = os.environ.get("FUSE_QKV_DOWN_PROJ", "0") == "1"
179+
180+
174181def get_gpt_layer_with_transformer_engine_spec(
175182 num_experts: Optional[int] = None,
176183 moe_grouped_gemm: Optional[bool] = False,
@@ -235,6 +242,31 @@ def get_gpt_layer_with_transformer_engine_spec(
235242
236243 if multi_latent_attention :
237244 assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
245+
246+ fuse_input_layernorm = (
247+ FUSE_INPUT_LAYERNORM
248+ and FUSE_QKV_DOWN_PROJ
249+ and backend.column_parallel_layer_norm_linear() is not None
250+ )
251+
252+ input_layernorm = (
253+ IdentityOp
254+ if fuse_input_layernorm
255+ else backend.layer_norm()
256+ )
257+
258+ linear_q_down_proj = (
259+ backend.column_parallel_layer_norm_linear()
260+ if fuse_input_layernorm
261+ else backend.linear()
262+ )
263+
264+ linear_kv_down_proj = (
265+ backend.column_parallel_layer_norm_linear()
266+ if fuse_input_layernorm
267+ else backend.linear()
268+ )
269+
238270 linear_q_up_proj = (
239271 backend.column_parallel_layer_norm_linear()
240272 if qk_layernorm
@@ -248,15 +280,15 @@ def get_gpt_layer_with_transformer_engine_spec(
248280 return ModuleSpec(
249281 module=TransformerLayer,
250282 submodules=TransformerLayerSubmodules(
251- input_layernorm=backend.layer_norm(),
283+ input_layernorm=input_layernorm,
252284 self_attention=ModuleSpec(
253285 module=MLASelfAttention,
254286 params={"attn_mask_type": AttnMaskType.causal},
255287 submodules=MLASelfAttentionSubmodules(
256288 linear_q_proj=backend.column_parallel_linear(),
257- linear_q_down_proj=backend.linear(),
289+ linear_q_down_proj=linear_q_down_proj,
258290 linear_q_up_proj=linear_q_up_proj,
259- linear_kv_down_proj=backend.linear(),
291+ linear_kv_down_proj=linear_kv_down_proj,
260292 linear_kv_up_proj=linear_kv_up_proj,
261293 core_attention=backend.core_attention(),
262294 linear_proj=backend.row_parallel_linear(),
@@ -268,6 +300,13 @@ def get_gpt_layer_with_transformer_engine_spec(
268300 pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
269301 mlp=mlp,
270302 mlp_bda=get_bias_dropout_add,
303+ sharded_state_dict_keys_map=(
304+ {
305+ "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
306+ }
307+ if fuse_input_layernorm
308+ else {}
309+ ),
271310 ),
272311 )
273312 else:
0 commit comments