1010 get_gpt_decoder_layer_specs ,
1111)
1212from megatron .core .models .gpt .experimental_attention_variant_module_specs import (
13- is_linear_attention_variant ,
13+ get_transformer_block_with_experimental_attention_variant_spec ,
14+ get_transformer_layer_with_experimental_attention_variant_spec ,
1415)
1516from megatron .core .models .gpt .heterogeneous .heterogeneous_layer_specs import (
1617 get_gpt_heterogeneous_layer_spec ,
@@ -46,7 +47,13 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
4647 else :
4748 use_te = args .transformer_impl == "transformer_engine"
4849
49- if args .num_experts or is_linear_attention_variant (args .experimental_attention_variant ):
50+ if args .experimental_attention_variant is not None :
51+ transformer_layer_spec = (
52+ get_transformer_block_with_experimental_attention_variant_spec (
53+ config = config , vp_stage = vp_stage
54+ )
55+ )
56+ elif args .num_experts :
5057 assert not (config .transformer_impl == "inference_optimized" )
5158 # Define the decoder block spec
5259 transformer_layer_spec = get_gpt_decoder_block_spec (
@@ -70,9 +77,19 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
7077 mtp_transformer_layer_spec = import_module (args .spec )
7178 else :
7279 # Define the decoder block spec
73- decoder_layer_specs = get_gpt_decoder_layer_specs (
74- config , use_transformer_engine = use_te , normalization = args .normalization , qk_l2_norm = args .qk_l2_norm , vp_stage = vp_stage
75- )
80+ if args .experimental_attention_variant is not None :
81+ decoder_layer_specs = (
82+ get_transformer_layer_with_experimental_attention_variant_spec (
83+ config = config
84+ )
85+ )
86+ else :
87+ decoder_layer_specs = get_gpt_decoder_layer_specs (
88+ config ,
89+ use_transformer_engine = use_te ,
90+ normalization = args .normalization ,
91+ qk_l2_norm = args .qk_l2_norm ,
92+ )
7693 mtp_transformer_layer_spec = decoder_layer_specs [- 1 ]
7794 # Use spec of the last layer in decoder block as spec of the transformer layer in MTP
7895 mtp_block_spec = get_gpt_mtp_block_spec (
0 commit comments