Commit 2832094 ("WIP")
1 parent: bbbedbb

2 files changed: +280, -64 lines

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 42 additions & 3 deletions
@@ -171,6 +171,13 @@ def get_gpt_layer_with_inference_spec(
     )
 
 
+import os
+# NOTE: These toggles are intentionally environment-driven for experimentation.
+# They are expected to be replaced by explicit config options.
+FUSE_INPUT_LAYERNORM = os.environ.get("FUSE_INPUT_LAYERNORM", "1") == "1"
+FUSE_QKV_DOWN_PROJ = os.environ.get("FUSE_QKV_DOWN_PROJ", "0") == "1"
+
+
 def get_gpt_layer_with_transformer_engine_spec(
     num_experts: Optional[int] = None,
     moe_grouped_gemm: Optional[bool] = False,
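
Because the toggles above are module-level constants read once at import time, they have to be set in the environment before gpt_layer_specs is imported. A minimal usage sketch (the FUSE_* names come from this diff; passing multi_latent_attention=True as a keyword is an assumption based on the "if multi_latent_attention:" branch below, not a confirmed signature):

import os

# Opt in to the fused down-projection path. These must be set before the
# module is imported, because the toggles are evaluated at import time.
os.environ["FUSE_INPUT_LAYERNORM"] = "1"  # defaults to "1" (on)
os.environ["FUSE_QKV_DOWN_PROJ"] = "1"    # defaults to "0" (off), so opt in

from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

layer_spec = get_gpt_layer_with_transformer_engine_spec(multi_latent_attention=True)
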
@@ -235,6 +242,31 @@ def get_gpt_layer_with_transformer_engine_spec(
 
     if multi_latent_attention:
         assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
+
+        fuse_input_layernorm = (
+            FUSE_INPUT_LAYERNORM
+            and FUSE_QKV_DOWN_PROJ
+            and backend.column_parallel_layer_norm_linear() is not None
+        )
+
+        input_layernorm = (
+            IdentityOp
+            if fuse_input_layernorm
+            else backend.layer_norm()
+        )
+
+        linear_q_down_proj = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+
+        linear_kv_down_proj = (
+            backend.column_parallel_layer_norm_linear()
+            if fuse_input_layernorm
+            else backend.linear()
+        )
+
         linear_q_up_proj = (
             backend.column_parallel_layer_norm_linear()
             if qk_layernorm
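
For intuition: the gating above enables the fusion only when both toggles are on and the backend actually provides a fused layernorm+linear class, and the standalone input layernorm then collapses to IdentityOp because each down projection normalizes its own input. A self-contained sketch of that equivalence in plain PyTorch (the LayerNormLinear class below is a stand-in for the backend's fused module, e.g. Transformer Engine's LayerNormLinear; it is not the Megatron code):

import torch
import torch.nn as nn

class LayerNormLinear(nn.Module):
    """Stand-in for a fused layernorm + linear module."""
    def __init__(self, dim, out_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.linear = nn.Linear(dim, out_dim, bias=False)

    def forward(self, x):
        return self.linear(self.layer_norm(x))

dim, q_dim, kv_dim = 16, 8, 4
x = torch.randn(2, dim)

# Unfused reference: one shared input layernorm feeding two plain projections.
norm = nn.LayerNorm(dim)
q_down = nn.Linear(dim, q_dim, bias=False)
kv_down = nn.Linear(dim, kv_dim, bias=False)
q_ref, kv_ref = q_down(norm(x)), kv_down(norm(x))

# Fused path: input_layernorm becomes an identity, and each down projection
# applies the (duplicated) layernorm itself.
identity = nn.Identity()
fused_q, fused_kv = LayerNormLinear(dim, q_dim), LayerNormLinear(dim, kv_dim)
with torch.no_grad():
    fused_q.linear.weight.copy_(q_down.weight)
    fused_kv.linear.weight.copy_(kv_down.weight)
q_fused, kv_fused = fused_q(identity(x)), fused_kv(identity(x))

assert torch.allclose(q_ref, q_fused, atol=1e-6)
assert torch.allclose(kv_ref, kv_fused, atol=1e-6)

Note that the fused path duplicates the layernorm parameters into each of the two down projections, which is why the checkpoint key remapping in the last hunk is needed.
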
@@ -248,15 +280,15 @@ def get_gpt_layer_with_transformer_engine_spec(
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(
-                input_layernorm=backend.layer_norm(),
+                input_layernorm=input_layernorm,
                 self_attention=ModuleSpec(
                     module=MLASelfAttention,
                     params={"attn_mask_type": AttnMaskType.causal},
                     submodules=MLASelfAttentionSubmodules(
                         linear_q_proj=backend.column_parallel_linear(),
-                        linear_q_down_proj=backend.linear(),
+                        linear_q_down_proj=linear_q_down_proj,
                         linear_q_up_proj=linear_q_up_proj,
-                        linear_kv_down_proj=backend.linear(),
+                        linear_kv_down_proj=linear_kv_down_proj,
                         linear_kv_up_proj=linear_kv_up_proj,
                         core_attention=backend.core_attention(),
                         linear_proj=backend.row_parallel_linear(),
@@ -268,6 +300,13 @@ def get_gpt_layer_with_transformer_engine_spec(
                 pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
                 mlp=mlp,
                 mlp_bda=get_bias_dropout_add,
+                sharded_state_dict_keys_map=(
+                    {
+                        "self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm.",
+                    }
+                    if fuse_input_layernorm
+                    else {}
+                ),
             ),
         )
     else:
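
The sharded_state_dict_keys_map entry keeps checkpoints layout-stable: when the layernorm parameters live inside the fused down-projection module, their keys are remapped back to the canonical input_layernorm prefix. A minimal sketch of what such a prefix remap accomplishes (illustrative only; remap_keys and the parameter names are assumptions, not Megatron's dist-checkpointing code):

def remap_keys(state_dict, keys_map):
    """Rewrite keys whose prefix appears in keys_map; leave the rest alone."""
    remapped = {}
    for key, value in state_dict.items():
        for src_prefix, dst_prefix in keys_map.items():
            if key.startswith(src_prefix):
                key = dst_prefix + key[len(src_prefix):]
                break
        remapped[key] = value
    return remapped

fused_sd = {
    "self_attention.linear_qkv_down_proj.layer_norm_weight": "ln_w",
    "self_attention.linear_qkv_down_proj.weight": "proj_w",
}
keys_map = {"self_attention.linear_qkv_down_proj.layer_norm_": "input_layernorm."}

print(remap_keys(fused_sd, keys_map))
# {'input_layernorm.weight': 'ln_w',
#  'self_attention.linear_qkv_down_proj.weight': 'proj_w'}

With the map applied, a checkpoint written from the fused layout still exposes the layernorm under "input_layernorm.", so the unfused spec can load it unchanged.
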
