diff --git a/src/megatron/bridge/models/qwen/qwen_provider.py b/src/megatron/bridge/models/qwen/qwen_provider.py
index f211c0c84c..2a7b6fd735 100644
--- a/src/megatron/bridge/models/qwen/qwen_provider.py
+++ b/src/megatron/bridge/models/qwen/qwen_provider.py
@@ -19,7 +19,9 @@
 import torch
 import torch.nn.functional as F
 
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
+    get_transformer_block_with_experimental_attention_variant_spec,
+)
 from megatron.core.transformer.spec_utils import ModuleSpec
 
 from megatron.bridge.models.gpt_provider import GPTModelProvider
@@ -433,8 +435,8 @@ class Qwen3MoEModelProvider235B_A22B(Qwen3MoEModelProvider):
 class Qwen3NextModelProvider(Qwen3MoEModelProvider):
     """Base provider for Qwen 3 Next Models."""
 
-    transformer_layer_spec: ModuleSpec | Callable[["GPTModelProvider"], ModuleSpec] = partial(
-        get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE
+    transformer_layer_spec: ModuleSpec | Callable[["GPTModelProvider"], ModuleSpec] = (
+        get_transformer_block_with_experimental_attention_variant_spec
     )
     layernorm_zero_centered_gamma: bool = True  # Zero-centered RMSNorm
 
diff --git a/src/megatron/bridge/training/mlm_compat/model.py b/src/megatron/bridge/training/mlm_compat/model.py
index 7bd088c44c..e32153956d 100644
--- a/src/megatron/bridge/training/mlm_compat/model.py
+++ b/src/megatron/bridge/training/mlm_compat/model.py
@@ -21,6 +21,9 @@
 from megatron.core.enums import ModelType
 from megatron.core.fp8_utils import correct_amax_history_if_needed
 from megatron.core.models.gpt import GPTModel
+from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
+    get_transformer_block_with_experimental_attention_variant_spec,
+)
 from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_decoder_block_spec,
     get_gpt_layer_local_spec,
@@ -90,7 +93,13 @@ def _gpt_provider(
     if config is None:
         config = _transformer_config_from_args(args)
 
-    if args.num_experts:
+    if args.experimental_attention_variant is not None:
+        transformer_layer_spec = (
+            get_transformer_block_with_experimental_attention_variant_spec(
+                config=config, vp_stage=vp_stage
+            )
+        )
+    elif args.num_experts:
         # Define the decoder block spec
         transformer_layer_spec = get_gpt_decoder_block_spec(
             config,
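
For context, below is a minimal, self-contained sketch of the spec-selection precedence this diff introduces in _gpt_provider. The two builder functions are stand-ins for the real Megatron-Core imports (only their call signatures as shown in the diff are assumed); select_layer_spec, the placeholder return values, and the example "linear_attention" flag value are hypothetical, not the actual implementation.

# Sketch only: stubbed stand-ins for the Megatron-Core spec builders.
from types import SimpleNamespace


def get_transformer_block_with_experimental_attention_variant_spec(config=None, vp_stage=None):
    """Stub for the new experimental-attention-variant block spec builder."""
    return "experimental-attention block spec"


def get_gpt_decoder_block_spec(config, vp_stage=None):
    """Stub for the existing MoE decoder block spec builder."""
    return "MoE decoder block spec"


def select_layer_spec(args, config, vp_stage=None):
    # The new branch is checked first, so an experimental attention variant
    # takes precedence over the MoE path even when num_experts is set.
    if args.experimental_attention_variant is not None:
        return get_transformer_block_with_experimental_attention_variant_spec(
            config=config, vp_stage=vp_stage
        )
    elif args.num_experts:
        return get_gpt_decoder_block_spec(config, vp_stage=vp_stage)
    return "dense layer spec"


# "linear_attention" is a hypothetical flag value for illustration only.
args = SimpleNamespace(experimental_attention_variant="linear_attention", num_experts=64)
print(select_layer_spec(args, config=None))  # -> "experimental-attention block spec"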