Commit 065a890

Fix Qwen3-Next hang on Blackwell, add a flag to control torch.compile (#2058)
Co-authored-by: Dennis(Zhenhuan) Liu <[email protected]>
Parent: fe8713a

6 files changed, +41 -13 lines changed

megatron/core/jit.py

Lines changed: 21 additions & 6 deletions

@@ -7,12 +7,27 @@
 jit_fuser = torch.jit.script
 # nvFuser is deprecated in PyTorch JIT starting from 2.2

-try:
-    if is_torch_min_version("2.2.0a0"):
-        jit_fuser = torch.compile
-except ImportError:

-    def noop_decorator(func):
-        return func
+def noop_decorator(func):
+    '''No-op decorator'''
+    return func

+
+def enable_jit_fuser():
+    '''Enable the JIT fuser'''
+    global jit_fuser
+    try:
+        if is_torch_min_version("2.2.0a0"):
+            jit_fuser = torch.compile
+    except ImportError:
+
+        jit_fuser = noop_decorator
+
+
+def disable_jit_fuser():
+    '''Disable the JIT fuser'''
+    global jit_fuser
     jit_fuser = noop_decorator
+
+
+enable_jit_fuser()
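
Note: with this refactor, jit_fuser stays a module-level decorator. enable_jit_fuser() runs at import time and selects torch.compile on PyTorch 2.2+, while disable_jit_fuser() rebinds it to the no-op, so the toggle only affects functions decorated after the call. A minimal usage sketch (the gelu_mul function is a hypothetical example, not part of the commit; assumes PyTorch >= 2.2 with megatron.core importable):

import torch
from megatron.core import jit

# Opt out of torch.compile before decorated functions are defined, e.g. when
# compilation hangs on a particular GPU / driver combination.
jit.disable_jit_fuser()

@jit.jit_fuser  # now a no-op decorator, so the function runs eagerly
def gelu_mul(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # Hypothetical fused elementwise op used only for illustration.
    return torch.nn.functional.gelu(gate) * x

print(gelu_mul(torch.randn(4, 8), torch.randn(4, 8)).shape)  # torch.Size([4, 8])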

megatron/core/ssm/gated_delta_net.py

Lines changed: 4 additions & 3 deletions

@@ -18,6 +18,7 @@
 from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
@@ -384,7 +385,7 @@ def forward(

         # RMSNorm
         nvtx_range_push(suffix="gated_norm")
-        norm_out = self._torch_compiled_gated_norm(core_attn_out, gate)
+        norm_out = self._apply_gated_norm(core_attn_out, gate)
         nvtx_range_pop(suffix="gated_norm")

         # Transpose: b s x --> s b x
@@ -399,8 +400,8 @@ def forward(

         return out, out_bias

-    @torch.compile
-    def _torch_compiled_gated_norm(self, x, gate):
+    @jit_fuser
+    def _apply_gated_norm(self, x, gate):
         # Output Norm
         x_dtype = x.dtype
         x = x.reshape(-1, x.shape[-1])
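
Note: the hunk above swaps the hard-coded @torch.compile for the project-wide @jit_fuser decorator on the gated-norm helper (and renames it), so the new --disable-jit-fuser switch also covers this path. The diff only shows the first lines of the method body, so the sketch below is an illustration of the decorated gated-RMSNorm pattern, not the actual Megatron implementation; the GatedNormExample class and the SiLU-gated RMSNorm body are assumptions.

import torch
from megatron.core.jit import jit_fuser

class GatedNormExample(torch.nn.Module):
    """Illustrative gated RMSNorm; stands in for GatedDeltaNet's real helper."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    @jit_fuser  # torch.compile when the fuser is enabled, plain eager call otherwise
    def _apply_gated_norm(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        x = x.reshape(-1, x.shape[-1])
        gate = gate.reshape(-1, gate.shape[-1])
        # RMS-normalize in fp32, apply a SiLU gate (assumed), cast back to the input dtype.
        variance = x.float().pow(2).mean(-1, keepdim=True)
        out = x.float() * torch.rsqrt(variance + self.eps) * self.weight.float()
        out = out * torch.nn.functional.silu(gate.float())
        return out.to(x_dtype)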

megatron/core/transformer/attention.py

Lines changed: 4 additions & 3 deletions

@@ -9,6 +9,7 @@

 from megatron.core import tensor_parallel
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.models.common.embeddings.rope_utils import (
     apply_rotary_pos_emb,
     apply_rotary_pos_emb_with_cos_sin,
@@ -936,7 +937,7 @@ def forward(
         # Output gate
         if gate is not None:
             nvtx_range_push(suffix="output_gate")
-            core_attn_out = self._torch_compiled_output_gate(core_attn_out, gate)
+            core_attn_out = self._apply_output_gate(core_attn_out, gate)
             nvtx_range_pop(suffix="output_gate")

         # =================
@@ -949,8 +950,8 @@ def forward(

         return output, bias

-    @torch.compile
-    def _torch_compiled_output_gate(self, x, gate):
+    @jit_fuser
+    def _apply_output_gate(self, x, gate):
         x_dtype = x.dtype
         gate = gate.contiguous()
         gate = gate.view(*x.shape)
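
Note: the same pattern is applied to self-attention's output gate: @torch.compile becomes @jit_fuser and the method is renamed. Only the first three lines of the body are visible in the diff; the sigmoid gating in the sketch below is an assumption about the rest of the method, included only to make the decorated helper self-contained.

import torch
from megatron.core.jit import jit_fuser

class OutputGateExample:
    """Illustrative output-gate helper mirroring Attention._apply_output_gate."""

    @jit_fuser
    def _apply_output_gate(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        gate = gate.contiguous()
        gate = gate.view(*x.shape)
        # Elementwise sigmoid gate applied in fp32 (assumed), cast back afterwards.
        return (x.float() * torch.sigmoid(gate.float())).to(x_dtype)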

megatron/core/transformer/transformer_block.py

Lines changed: 6 additions & 0 deletions

@@ -764,6 +764,12 @@ def sharded_state_dict(
         elif isinstance(self.config.moe_layer_freq, list):
             non_homogeneous_layers = True

+        if isinstance(self.config.linear_attention_freq, int):
+            if self.config.linear_attention_freq > 1:
+                non_homogeneous_layers = True
+        elif isinstance(self.config.linear_attention_freq, list):
+            non_homogeneous_layers = True
+
         if self.config.heterogeneous_block_specs:
             non_homogeneous_layers = True

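
Note: for sharded checkpointing, a stack that interleaves linear-attention and standard-attention layers can no longer be treated as homogeneous, so sharded_state_dict now applies the same check to linear_attention_freq that it already applies to moe_layer_freq. Below is a hedged sketch of that check, factored into a standalone helper purely for illustration; the helper name is hypothetical and the real code stays inline in the method.

from typing import List, Optional, Union

def linear_attention_makes_layers_non_homogeneous(
    linear_attention_freq: Optional[Union[int, List[int]]]
) -> bool:
    # An integer frequency of 1 keeps every layer identical; anything larger
    # interleaves layer types, and an explicit per-layer list always does.
    if isinstance(linear_attention_freq, int):
        return linear_attention_freq > 1
    if isinstance(linear_attention_freq, list):
        return True
    return False

assert linear_attention_makes_layers_non_homogeneous(1) is False
assert linear_attention_makes_layers_non_homogeneous(4) is True
assert linear_attention_makes_layers_non_homogeneous([1, 1, 1, 0]) is True
assert linear_attention_makes_layers_non_homogeneous(None) is False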

megatron/training/arguments.py

Lines changed: 2 additions & 1 deletion

@@ -2315,7 +2315,8 @@ def _add_training_args(parser):
                        help='The communicator group names to use high priority streams.')
     group.add_argument('--use-te-activation-func', action='store_true',
                        help='Use activation function kernel from Transformer Engine in MLP module.')
-
+    group.add_argument('--disable-jit-fuser', action='store_true',
+                       help='Disable the JIT fuser.')
     return parser

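
Note: the new flag is a plain store_true argument in the training group, so it defaults to False and flips to True only when --disable-jit-fuser is passed on the command line. A standalone argparse sketch of the same wiring (the parser here is illustrative; in Megatron the argument is registered inside _add_training_args):

import argparse

parser = argparse.ArgumentParser(description='sketch of the training-argument wiring')
group = parser.add_argument_group(title='training')
group.add_argument('--disable-jit-fuser', action='store_true',
                   help='Disable the JIT fuser.')

args = parser.parse_args(['--disable-jit-fuser'])
print(args.disable_jit_fuser)  # True; omitting the flag would leave it False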

megatron/training/global_vars.py

Lines changed: 4 additions & 0 deletions

@@ -9,6 +9,7 @@
 from megatron.core import Timers
 from megatron.core.config import set_experimental_flag
 from megatron.core.energy_monitor import EnergyMonitor
+from megatron.core.jit import disable_jit_fuser
 from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator
 from megatron.training import dist_signal_handler
 from megatron.training.tokenizer import build_tokenizer
@@ -111,6 +112,9 @@ def set_global_variables(args, build_tokenizer=True):
     if args.exit_signal_handler:
         _set_signal_handler()

+    if args.disable_jit_fuser:
+        disable_jit_fuser()
+

 def unset_global_variables():
     """Unset global vars.
