
Commit 3b17f91

yuzhongw-nvidia and Victarry authored and committed
[Dev] Fix Qwen3-Next hang on Blackwell, add a flag to control torch.compile (NVIDIA#2058)
Co-authored-by: Dennis(Zhenhuan) Liu <denliu@nvidia.com>
1 parent c69e6c3 commit 3b17f91

File tree: 6 files changed (+41, −13 lines)


megatron/core/jit.py

Lines changed: 21 additions & 6 deletions
@@ -7,12 +7,27 @@
 jit_fuser = torch.jit.script
 # nvFuser is deprecated in PyTorch JIT starting from 2.2
 
-try:
-    if is_torch_min_version("2.2.0a0"):
-        jit_fuser = torch.compile
-except ImportError:
 
-    def noop_decorator(func):
-        return func
+def noop_decorator(func):
+    '''No-op decorator'''
+    return func
 
+
+def enable_jit_fuser():
+    '''Enable the JIT fuser'''
+    global jit_fuser
+    try:
+        if is_torch_min_version("2.2.0a0"):
+            jit_fuser = torch.compile
+    except ImportError:
+        jit_fuser = noop_decorator
+
+
+def disable_jit_fuser():
+    '''Disable the JIT fuser'''
+    global jit_fuser
     jit_fuser = noop_decorator
+
+
+enable_jit_fuser()
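
For orientation, a minimal sketch of how the rebindable decorator behaves after this change (standalone script; _gated_silu and _plain_add are hypothetical helpers, not part of the commit; assumes PyTorch >= 2.2 so jit_fuser resolves to torch.compile):

import torch

from megatron.core import jit
from megatron.core.jit import jit_fuser


@jit_fuser  # torch.compile on PyTorch >= 2.2.0a0
def _gated_silu(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    return x * torch.nn.functional.silu(gate)


# This is what --disable-jit-fuser triggers at startup: rebind the module-level
# jit_fuser to noop_decorator so later decorations leave functions uncompiled.
jit.disable_jit_fuser()


@jit.jit_fuser  # now the no-op decorator; the function runs eagerly
def _plain_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

Because decoration happens at function-definition time, disable_jit_fuser() only affects functions defined after the switch is flipped.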

megatron/core/ssm/gated_delta_net.py

Lines changed: 4 additions & 3 deletions
@@ -18,6 +18,7 @@
 from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
@@ -384,7 +385,7 @@ def forward(
 
         # RMSNorm
         nvtx_range_push(suffix="gated_norm")
-        norm_out = self._torch_compiled_gated_norm(core_attn_out, gate)
+        norm_out = self._apply_gated_norm(core_attn_out, gate)
         nvtx_range_pop(suffix="gated_norm")
 
         # Transpose: b s x --> s b x
@@ -399,8 +400,8 @@ def forward(
 
         return out, out_bias
 
-    @torch.compile
-    def _torch_compiled_gated_norm(self, x, gate):
+    @jit_fuser
+    def _apply_gated_norm(self, x, gate):
         # Output Norm
         x_dtype = x.dtype
         x = x.reshape(-1, x.shape[-1])
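
The switch from a hard-coded @torch.compile to @jit_fuser is what makes the new --disable-jit-fuser flag effective here. A sketch of the decoration pattern with a hypothetical gated-norm body (it does not reproduce the real GatedDeltaNet math; assumes PyTorch >= 2.2 so jit_fuser is torch.compile):

import torch

from megatron.core.jit import jit_fuser


class _HypotheticalGatedNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    @jit_fuser  # resolved once, when the class body is executed at import time
    def _apply_gated_norm(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Illustrative gated RMSNorm: normalize x, then apply a SiLU gate.
        x_dtype = x.dtype
        var = x.float().pow(2).mean(-1, keepdim=True)
        x = x.float() * torch.rsqrt(var + self.eps)
        return (x * self.weight * torch.nn.functional.silu(gate.float())).to(x_dtype)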

megatron/core/transformer/attention.py

Lines changed: 4 additions & 3 deletions
@@ -9,6 +9,7 @@
 
 from megatron.core import tensor_parallel
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.models.common.embeddings.rope_utils import (
     apply_rotary_pos_emb,
     apply_rotary_pos_emb_with_cos_sin,
@@ -923,7 +924,7 @@ def forward(
         # Output gate
         if gate is not None:
             nvtx_range_push(suffix="output_gate")
-            core_attn_out = self._torch_compiled_output_gate(core_attn_out, gate)
+            core_attn_out = self._apply_output_gate(core_attn_out, gate)
             nvtx_range_pop(suffix="output_gate")
 
         # =================
@@ -936,8 +937,8 @@ def forward(
 
         return output, bias
 
-    @torch.compile
-    def _torch_compiled_output_gate(self, x, gate):
+    @jit_fuser
+    def _apply_output_gate(self, x, gate):
         x_dtype = x.dtype
         gate = gate.contiguous()
         gate = gate.view(*x.shape)
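
The same pattern applies to the attention output gate. A hypothetical standalone stand-in that mirrors only the prologue visible in this diff (the actual gating math in attention.py may differ; assumes PyTorch >= 2.2):

import torch

from megatron.core.jit import jit_fuser


@jit_fuser  # honors --disable-jit-fuser, unlike a hard-coded @torch.compile
def _output_gate_sketch(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    x_dtype = x.dtype
    gate = gate.contiguous().view(*x.shape)
    # Assumed sigmoid gating; illustrative only.
    return (x.float() * torch.sigmoid(gate.float())).to(x_dtype)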

megatron/core/transformer/transformer_block.py

Lines changed: 6 additions & 0 deletions
@@ -793,6 +793,12 @@ def sharded_state_dict(
         elif isinstance(self.config.moe_layer_freq, list):
             non_homogeneous_layers = True
 
+        if isinstance(self.config.linear_attention_freq, int):
+            if self.config.linear_attention_freq > 1:
+                non_homogeneous_layers = True
+        elif isinstance(self.config.linear_attention_freq, list):
+            non_homogeneous_layers = True
+
         if self.config.heterogeneous_block_specs:
             non_homogeneous_layers = True
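
By analogy with the moe_layer_freq handling just above, a standalone sketch of the new check (the helper name and the reading of the int value as "linear attention every N-th layer" are assumptions):

from typing import List, Optional, Union


def _linear_attention_makes_layers_non_homogeneous(
    linear_attention_freq: Optional[Union[int, List[int]]]
) -> bool:
    # freq == 1 (or None) keeps every layer identical; freq > 1 mixes linear-attention
    # layers with regular ones, and an explicit per-layer list is non-homogeneous.
    if isinstance(linear_attention_freq, int):
        return linear_attention_freq > 1
    if isinstance(linear_attention_freq, list):
        return True
    return False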

megatron/training/arguments.py

Lines changed: 2 additions & 1 deletion
@@ -2337,7 +2337,8 @@ def _add_training_args(parser):
                        help='The communicator group names to use high priority streams.')
     group.add_argument('--use-te-activation-func', action='store_true',
                        help='Use activation function kernel from Transformer Engine in MLP module.')
-
+    group.add_argument('--disable-jit-fuser', action='store_true',
+                       help='Disable the JIT fuser.')
     return parser
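
The flag defined here is consumed in megatron/training/global_vars.py (next file). A rough standalone sketch of the combined flow, using plain argparse in place of Megatron's argument builder:

import argparse

from megatron.core.jit import disable_jit_fuser


def main() -> None:
    # Stand-in for _add_training_args(): argparse maps --disable-jit-fuser
    # to args.disable_jit_fuser.
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-jit-fuser', action='store_true',
                        help='Disable the JIT fuser.')
    args = parser.parse_args()

    # Stand-in for the hook added to set_global_variables(): flip the global
    # switch so later jit_fuser decorations become no-ops.
    if args.disable_jit_fuser:
        disable_jit_fuser()


if __name__ == "__main__":
    main()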

megatron/training/global_vars.py

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,7 @@
 from megatron.core import Timers
 from megatron.core.config import set_experimental_flag
 from megatron.core.energy_monitor import EnergyMonitor
+from megatron.core.jit import disable_jit_fuser
 from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator
 from megatron.training import dist_signal_handler
 from megatron.training.tokenizer import build_tokenizer
@@ -111,6 +112,9 @@ def set_global_variables(args, build_tokenizer=True):
     if args.exit_signal_handler:
         _set_signal_handler()
 
+    if args.disable_jit_fuser:
+        disable_jit_fuser()
+
 
 def unset_global_variables():
     """Unset global vars.
