
Commit 7c41eb3

WIP
1 parent bbbedbb

2 files changed: +279 -64 lines changed

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 41 additions & 3 deletions
@@ -171,6 +171,12 @@ def get_gpt_layer_with_inference_spec(
     )
 
 
+import os
+# NOTE: These toggles are intentionally environment-driven for experimentation.
+# They are expected to be replaced by explicit config options.
+FUSE_INPUT_LAYERNORM = os.environ.get("FUSE_INPUT_LAYERNORM", "1") == "1"
+FUSE_QKV_DOWN_PROJ = os.environ.get("FUSE_QKV_DOWN_PROJ", "0") == "1"
+
 def get_gpt_layer_with_transformer_engine_spec(
     num_experts: Optional[int] = None,
     moe_grouped_gemm: Optional[bool] = False,
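
Both toggles are read once, at module import time (the assignments sit at module scope), so they must be in the environment before the module is imported. A minimal sketch of flipping them from Python rather than the shell; the import path follows from the file's location, the rest is illustrative:

import os

# Must run before the import below, since the toggles are evaluated at
# import time. Defaults: FUSE_INPUT_LAYERNORM="1", FUSE_QKV_DOWN_PROJ="0".
os.environ["FUSE_INPUT_LAYERNORM"] = "1"
os.environ["FUSE_QKV_DOWN_PROJ"] = "1"

from megatron.core.models.gpt import gpt_layer_specs  # noqa: E402

Equivalently, set them in the launching shell, e.g. FUSE_INPUT_LAYERNORM=1 FUSE_QKV_DOWN_PROJ=1 python <training script>.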
@@ -235,6 +241,31 @@ def get_gpt_layer_with_transformer_engine_spec(
 
     if multi_latent_attention:
         assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
+
+        fuse_input_layernorm = (
+            FUSE_INPUT_LAYERNORM
+            and FUSE_QKV_DOWN_PROJ
+            and backend.column_parallel_layer_norm_linear() is not None
+        )
+
+        input_layernorm = (
+            IdentityOp
+            if fuse_input_layernorm
+            else backend.layer_norm()
+        )
+
+        linear_q_down_proj = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+
+        linear_kv_down_proj = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+
         linear_q_up_proj = (
             backend.column_parallel_layer_norm_linear()
             if qk_layernorm
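
The three conditionals above share one pivot: fusion is active only when both environment toggles are on and the backend actually exposes a fused LayerNorm+Linear module. A runnable sketch of the same selection pattern, using a stand-in backend (the class and its return values are placeholders, not Megatron's real backend API):

class StubBackend:
    """Stand-in for the spec backend; return values are placeholder strings."""

    def layer_norm(self):
        return "LayerNormSpec"

    def linear(self):
        return "LinearSpec"

    def column_parallel_layer_norm_linear(self):
        # A backend without a fused LayerNorm+Linear op would return None here.
        return "LayerNormLinearSpec"


def resolve_input_path(backend, fuse_ln, fuse_qkv):
    # Mirrors the diff: fuse only if both toggles are set AND the backend
    # provides a fused LayerNorm+Linear module.
    fuse = (
        fuse_ln
        and fuse_qkv
        and backend.column_parallel_layer_norm_linear() is not None
    )
    input_layernorm = "IdentityOp" if fuse else backend.layer_norm()
    down_proj = backend.column_parallel_layer_norm_linear() if fuse else backend.linear()
    return input_layernorm, down_proj


# Defaults (FUSE_INPUT_LAYERNORM="1", FUSE_QKV_DOWN_PROJ="0"): unfused path,
# i.e. a standalone LayerNorm feeding plain down projections.
assert resolve_input_path(StubBackend(), True, False) == ("LayerNormSpec", "LinearSpec")
# Both toggles on: the input LayerNorm folds into the Q/KV down projections.
assert resolve_input_path(StubBackend(), True, True) == ("IdentityOp", "LayerNormLinearSpec")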
@@ -248,15 +279,15 @@ def get_gpt_layer_with_transformer_engine_spec(
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(
-                input_layernorm=backend.layer_norm(),
+                input_layernorm=input_layernorm,
                 self_attention=ModuleSpec(
                     module=MLASelfAttention,
                     params={"attn_mask_type": AttnMaskType.causal},
                     submodules=MLASelfAttentionSubmodules(
                         linear_q_proj=backend.column_parallel_linear(),
-                        linear_q_down_proj=backend.linear(),
+                        linear_q_down_proj=linear_q_down_proj,
                         linear_q_up_proj=linear_q_up_proj,
-                        linear_kv_down_proj=backend.linear(),
+                        linear_kv_down_proj=linear_kv_down_proj,
                         linear_kv_up_proj=linear_kv_up_proj,
                         core_attention=backend.core_attention(),
                         linear_proj=backend.row_parallel_linear(),
@@ -268,6 +299,13 @@ def get_gpt_layer_with_transformer_engine_spec(
                 pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
                 mlp=mlp,
                 mlp_bda=get_bias_dropout_add,
+                sharded_state_dict_keys_map=(
+                    {
+                        "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
+                    }
+                    if fuse_input_layernorm
+                    else {}
+                ),
             ),
         )
     else:
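
The sharded_state_dict_keys_map entry is what keeps fused and unfused checkpoints interchangeable: in the fused layout the input LayerNorm's parameters live inside the fused down projection, and the map renames them back to the unfused layout's input_layernorm names. A simplified illustration of that prefix rename (the remap_keys helper and the layer_norm_weight parameter name are assumptions for illustration; Megatron applies the map internally when building the sharded state dict):

def remap_keys(state_dict, keys_map):
    """Rename keys by prefix, the way a sharded_state_dict_keys_map requests."""
    out = {}
    for key, value in state_dict.items():
        for src_prefix, dst_prefix in keys_map.items():
            if key.startswith(src_prefix):
                key = dst_prefix + key[len(src_prefix):]
                break
        out[key] = value
    return out


# Fused layout: the LayerNorm weight is stored inside the fused down proj.
fused = {"self_attention.linear_qkv_down_proj.layer_norm_weight": 0}
keys_map = {"self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm."}
# After remapping, the parameter carries the unfused layout's name.
assert remap_keys(fused, keys_map) == {"input_layernorm.weight": 0}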
