Skip to content

Commit 674fec3

Browse files
yuzhongw-nvidia and ko3n1g
authored and committed
[main] feat(moe): Support gated delta net for Qwen3-Next (1/4) (NVIDIA#1989)
Signed-off-by: oliver könig <[email protected]> Co-authored-by: oliver könig <[email protected]>
1 parent 71bb0fd commit 674fec3

File tree

21 files changed

+2463
-54
lines changed

21 files changed

+2463
-54
lines changed

gpt_builders.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
get_gpt_layer_with_inference_spec,
99
get_gpt_mtp_block_spec,
1010
)
11+
from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
12+
get_transformer_block_with_experimental_attention_variant_spec,
13+
)
1114
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
1215
get_gpt_heterogeneous_layer_spec,
1316
)
@@ -42,7 +45,13 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
4245
else:
4346
use_te = args.transformer_impl == "transformer_engine"
4447

45-
if args.num_experts:
48+
if args.experimental_attention_variant is not None:
49+
transformer_layer_spec = (
50+
get_transformer_block_with_experimental_attention_variant_spec(
51+
config=config, vp_stage=vp_stage
52+
)
53+
)
54+
elif args.num_experts:
4655
assert not (config.transformer_impl == "inference_optimized")
4756
# Define the decoder block spec
4857
transformer_layer_spec = get_gpt_decoder_block_spec(

megatron/core/jit.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,27 @@
77
jit_fuser = torch.jit.script
88
# nvFuser is deprecated in PyTorch JIT starting from 2.2
99

10-
try:
11-
if is_torch_min_version("2.2.0a0"):
12-
jit_fuser = torch.compile
13-
except ImportError:
1410

15-
def noop_decorator(func):
16-
return func
def noop_decorator(func):
    """Identity decorator: return ``func`` unchanged (used when JIT fusion is disabled)."""
    return func
1714

def enable_jit_fuser():
    """Enable the JIT fuser by rebinding the module-level ``jit_fuser``.

    Restores the default ``torch.jit.script`` fuser first, then upgrades to
    ``torch.compile`` when the installed PyTorch is >= 2.2 (nvFuser is
    deprecated in PyTorch JIT starting from 2.2). Falls back to
    ``noop_decorator`` if the version check raises ``ImportError``.
    """
    global jit_fuser
    # Re-establish the baseline fuser so that enable() after disable()
    # actually re-enables fusion even on torch < 2.2; without this, the
    # function would leave ``jit_fuser`` as ``noop_decorator``.
    jit_fuser = torch.jit.script
    try:
        if is_torch_min_version("2.2.0a0"):
            jit_fuser = torch.compile
    except ImportError:
        jit_fuser = noop_decorator
25+
26+
def disable_jit_fuser():
    """Disable JIT fusion by rebinding the module-level ``jit_fuser`` to a no-op."""
    global jit_fuser
    jit_fuser = noop_decorator
31+
32+
33+
# Initialize the module-level ``jit_fuser`` at import time (enabled by default).
enable_jit_fuser()

0 commit comments

Comments
 (0)