Commit f561063

WIP
1 parent bbbedbb commit f561063

2 files changed (+281 −64 lines)


megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 43 additions & 3 deletions
@@ -171,6 +171,14 @@ def get_gpt_layer_with_inference_spec(
     )
 
 
+import os
+# NOTE: These toggles are intentionally environment-driven for experimentation.
+# They are expected to be replaced by explicit config options.
+FUSE_INPUT_LAYERNORM = os.environ.get("FUSE_INPUT_LAYERNORM", "1") == "1"
+FUSE_QKV_DOWN_PROJ = os.environ.get("FUSE_QKV_DOWN_PROJ", "0") == "1"
+
+
+
 def get_gpt_layer_with_transformer_engine_spec(
     num_experts: Optional[int] = None,
     moe_grouped_gemm: Optional[bool] = False,
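
Both toggles are read once at import time, so they have to be exported before gpt_layer_specs is imported. A minimal sketch of enabling the fused path (the flag names and module path come from the hunk above; passing multi_latent_attention as a keyword is assumed from the branch further down):

import os

# Both flags must be "1" for the fusion to be considered
# (FUSE_INPUT_LAYERNORM already defaults to "1", FUSE_QKV_DOWN_PROJ to "0").
os.environ["FUSE_INPUT_LAYERNORM"] = "1"
os.environ["FUSE_QKV_DOWN_PROJ"] = "1"

# Import only after setting the variables; the module evaluates them at import time.
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

spec = get_gpt_layer_with_transformer_engine_spec(multi_latent_attention=True)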
@@ -235,6 +243,31 @@ def get_gpt_layer_with_transformer_engine_spec(
 
     if multi_latent_attention:
         assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
+
+        fuse_input_layernorm = (
+            FUSE_INPUT_LAYERNORM
+            and FUSE_QKV_DOWN_PROJ
+            and backend.column_parallel_layer_norm_linear() is not None
+        )
+
+        input_layernorm = (
+            IdentityOp
+            if fuse_input_layernorm
+            else backend.layer_norm()
+        )
+
+        linear_q_down_proj = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+
+        linear_kv_down_proj = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+
         linear_q_up_proj = (
             backend.column_parallel_layer_norm_linear()
             if qk_layernorm
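
The condition above only takes the fused path when the backend actually provides a combined layernorm+linear module (with Transformer Engine this is typically a LayerNormLinear-style layer); in that case the normalization moves inside the down-projections and the standalone input_layernorm collapses to IdentityOp. A rough torch-only sketch of the equivalence being exploited (illustrative, not the backend implementation):

import torch
from torch import nn

hidden, q_lora_rank = 16, 8
x = torch.randn(2, hidden)

# Unfused layout: a standalone input layernorm followed by a plain down-projection.
ln = nn.LayerNorm(hidden, elementwise_affine=False)
down = nn.Linear(hidden, q_lora_rank, bias=False)
y_unfused = down(ln(x))

# Fused layout: one module normalizes and projects in a single call, which is what
# backend.column_parallel_layer_norm_linear() supplies in the spec (there the two
# steps can additionally be fused at the kernel level).
class LayerNormLinear(nn.Module):
    def __init__(self, ln, linear):
        super().__init__()
        self.ln, self.linear = ln, linear

    def forward(self, x):
        return self.linear(self.ln(x))

y_fused = LayerNormLinear(ln, down)(x)
assert torch.allclose(y_unfused, y_fused)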
@@ -248,15 +281,15 @@ def get_gpt_layer_with_transformer_engine_spec(
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(
-                input_layernorm=backend.layer_norm(),
+                input_layernorm=input_layernorm,
                 self_attention=ModuleSpec(
                     module=MLASelfAttention,
                     params={"attn_mask_type": AttnMaskType.causal},
                     submodules=MLASelfAttentionSubmodules(
                         linear_q_proj=backend.column_parallel_linear(),
-                        linear_q_down_proj=backend.linear(),
+                        linear_q_down_proj=linear_q_down_proj,
                         linear_q_up_proj=linear_q_up_proj,
-                        linear_kv_down_proj=backend.linear(),
+                        linear_kv_down_proj=linear_kv_down_proj,
                         linear_kv_up_proj=linear_kv_up_proj,
                         core_attention=backend.core_attention(),
                         linear_proj=backend.row_parallel_linear(),
@@ -268,6 +301,13 @@ def get_gpt_layer_with_transformer_engine_spec(
                 pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
                 mlp=mlp,
                 mlp_bda=get_bias_dropout_add,
+                sharded_state_dict_keys_map=(
+                    {
+                        "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
+                    }
+                    if fuse_input_layernorm
+                    else {}
+                ),
             ),
         )
     else:
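
The sharded_state_dict_keys_map entry is what keeps checkpoints compatible across the two layouts: when the fusion is active, parameters saved under the fused projection's layer_norm_* names are mapped to the input_layernorm.* names of the unfused spec. A rough sketch of that prefix renaming (the key strings mirror the map above; the actual remapping is done by the layer's sharded_state_dict machinery, not by this helper):

def remap_prefixes(state_dict, keys_map):
    # Rename any key that starts with a prefix in keys_map; leave the rest untouched.
    out = {}
    for key, value in state_dict.items():
        for src, dst in keys_map.items():
            if key.startswith(src):
                key = dst + key[len(src):]
                break
        out[key] = value
    return out

keys_map = {"self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm."}
# A fused-layout key such as "self_attention.linear_qkv_down_proj.layer_norm_weight"
# would be exposed as "input_layernorm.weight", matching the unfused layout.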

0 commit comments
