Skip to content

Commit 6e2153b

Browse files
[dev] feat(moe): Cherry-pick #1989 back to dev (#3011)
Signed-off-by: Oliver König <[email protected]>; Co-authored-by: Oliver König <[email protected]>
1 parent: ac9f665 — commit 6e2153b

File tree

13 files changed

+749
-503
lines changed

13 files changed

+749
-503
lines changed

gpt_builders.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
get_gpt_decoder_layer_specs,
1111
)
1212
from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
13-
is_linear_attention_variant,
13+
get_transformer_block_with_experimental_attention_variant_spec,
14+
get_transformer_layer_with_experimental_attention_variant_spec,
1415
)
1516
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
1617
get_gpt_heterogeneous_layer_spec,
@@ -46,7 +47,13 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
4647
else:
4748
use_te = args.transformer_impl == "transformer_engine"
4849

49-
if args.num_experts or is_linear_attention_variant(args.experimental_attention_variant):
50+
if args.experimental_attention_variant is not None:
51+
transformer_layer_spec = (
52+
get_transformer_block_with_experimental_attention_variant_spec(
53+
config=config, vp_stage=vp_stage
54+
)
55+
)
56+
elif args.num_experts:
5057
assert not (config.transformer_impl == "inference_optimized")
5158
# Define the decoder block spec
5259
transformer_layer_spec = get_gpt_decoder_block_spec(
@@ -70,9 +77,19 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
7077
mtp_transformer_layer_spec = import_module(args.spec)
7178
else:
7279
# Define the decoder block spec
73-
decoder_layer_specs = get_gpt_decoder_layer_specs(
74-
config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage
75-
)
80+
if args.experimental_attention_variant is not None:
81+
decoder_layer_specs = (
82+
get_transformer_layer_with_experimental_attention_variant_spec(
83+
config=config
84+
)
85+
)
86+
else:
87+
decoder_layer_specs = get_gpt_decoder_layer_specs(
88+
config,
89+
use_transformer_engine=use_te,
90+
normalization=args.normalization,
91+
qk_l2_norm=args.qk_l2_norm,
92+
)
7693
mtp_transformer_layer_spec = decoder_layer_specs[-1]
7794
# Use spec of the last layer in decoder block as spec of the transformer layer in MTP
7895
mtp_block_spec = get_gpt_mtp_block_spec(

0 commit comments

Comments (0)