Skip to content

Commit 99fab2c

Browse files
authored
[Bugfix] Fix Mistral Large 3 NVFP4 TRTLLM MoE (sgl-project#18065)
1 parent a45647b commit 99fab2c

File tree

2 files changed

+115
-111
lines changed

2 files changed

+115
-111
lines changed

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 94 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -461,123 +461,114 @@ def apply(
461461
dispatch_output: StandardDispatchOutput,
462462
) -> CombineInput:
463463

464-
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
465464
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
466465

467466
x = dispatch_output.hidden_states
468467
topk_output = dispatch_output.topk_output
469-
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
470-
471-
output = cutlass_moe_fp4(
472-
a=x,
473-
a1_gscale=layer.w13_input_scale_quant,
474-
w1_fp4=layer.w13_weight,
475-
w1_blockscale=layer.w13_weight_scale,
476-
w1_alphas=layer.g1_alphas,
477-
a2_gscale=layer.w2_input_scale_quant,
478-
w2_fp4=layer.w2_weight,
479-
w2_blockscale=layer.w2_weight_scale,
480-
w2_alphas=layer.g2_alphas,
481-
topk_weights=topk_weights,
482-
topk_ids=topk_ids,
483-
params=layer.cutlass_moe_params,
484-
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
485-
).to(x.dtype)
486468

487-
return StandardCombineInput(hidden_states=output)
488-
489-
def apply_with_router_logits(
490-
self,
491-
layer: torch.nn.Module,
492-
dispatch_output: StandardDispatchOutput,
493-
) -> torch.Tensor:
494-
assert self.use_flashinfer_trtllm
495-
496-
x = dispatch_output.hidden_states
497-
topk_output = dispatch_output.topk_output
498-
499-
from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe
469+
if self.use_flashinfer_trtllm:
470+
from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe
500471

501-
from sglang.srt.layers.moe.utils import RoutingMethodType
472+
router_logits = topk_output.router_logits
473+
topk_config = topk_output.topk_config
502474

503-
router_logits = topk_output.router_logits
504-
topk_config = topk_output.topk_config
475+
# Quantize input hidden states using fp4_quantize
476+
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
477+
x,
478+
layer.w13_input_scale_quant,
479+
self.group_size, # sf_vec_size
480+
False, # use_ue8m0
481+
False, # is_sf_swizzled_layout
482+
)
483+
hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
484+
hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
505485

506-
# Quantize input hidden states using fp4_quantize
507-
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
508-
x,
509-
layer.w13_input_scale_quant,
510-
self.group_size, # sf_vec_size
511-
False, # use_ue8m0
512-
False, # is_sf_swizzled_layout
513-
)
514-
hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
515-
hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
486+
correction_bias = (
487+
None
488+
if topk_config.correction_bias is None
489+
else topk_config.correction_bias.to(x.dtype)
490+
)
516491

517-
correction_bias = (
518-
None
519-
if topk_config.correction_bias is None
520-
else topk_config.correction_bias.to(x.dtype)
521-
)
492+
assert layer.routing_method_type is not None
522493

523-
assert layer.routing_method_type is not None
494+
# DeepSeekV3 style routing requires float32 router logits
495+
if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
496+
router_logits = router_logits.to(torch.float32)
524497

525-
# DeepSeekV3 style routing requires float32 router logits
526-
if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
527-
router_logits = router_logits.to(torch.float32)
498+
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
499+
routed_scaling_factor = (
500+
routed_scaling_factor if routed_scaling_factor is not None else 1.0
501+
)
528502

529-
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
530-
routed_scaling_factor = (
531-
routed_scaling_factor if routed_scaling_factor is not None else 1.0
532-
)
503+
with use_symmetric_memory(
504+
get_tp_group(), disabled=not is_allocation_symmetric()
505+
):
506+
num_tokens = hs_fp4.shape[0]
507+
hidden_size = (
508+
hs_fp4.shape[-1] * 2
509+
if hs_fp4.dtype == torch.uint8
510+
else hs_fp4.shape[-1]
511+
)
512+
symm_output = torch.empty(
513+
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
514+
)
533515

534-
with use_symmetric_memory(
535-
get_tp_group(), disabled=not is_allocation_symmetric()
536-
):
537-
num_tokens = hs_fp4.shape[0]
538-
hidden_size = (
539-
hs_fp4.shape[-1] * 2
540-
if hs_fp4.dtype == torch.uint8
541-
else hs_fp4.shape[-1]
542-
)
543-
symm_output = torch.empty(
544-
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
545-
)
516+
output = trtllm_fp4_block_scale_moe(
517+
routing_logits=router_logits,
518+
routing_bias=correction_bias,
519+
hidden_states=hs_fp4,
520+
hidden_states_scale=hs_scale,
521+
gemm1_weights=layer.gemm1_weights_fp4_shuffled,
522+
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
523+
torch.float8_e4m3fn
524+
),
525+
gemm1_bias=None,
526+
gemm1_alpha=None,
527+
gemm1_beta=None,
528+
gemm1_clamp_limit=None,
529+
gemm2_weights=layer.gemm2_weights_fp4_shuffled,
530+
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
531+
torch.float8_e4m3fn
532+
),
533+
gemm2_bias=None,
534+
output1_scale_scalar=layer.g1_scale_c,
535+
output1_scale_gate_scalar=layer.g1_alphas,
536+
output2_scale_scalar=layer.g2_alphas,
537+
num_experts=layer.num_experts,
538+
top_k=topk_config.top_k,
539+
n_group=topk_config.num_expert_group,
540+
topk_group=topk_config.topk_group,
541+
intermediate_size=layer.intermediate_size_per_partition,
542+
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
543+
local_num_experts=layer.num_local_experts,
544+
routed_scaling_factor=routed_scaling_factor,
545+
routing_method_type=layer.routing_method_type,
546+
do_finalize=True,
547+
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
548+
output=symm_output,
549+
)[0]
550+
else:
551+
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
552+
553+
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
554+
555+
output = cutlass_moe_fp4(
556+
a=x,
557+
a1_gscale=layer.w13_input_scale_quant,
558+
w1_fp4=layer.w13_weight,
559+
w1_blockscale=layer.w13_weight_scale,
560+
w1_alphas=layer.g1_alphas,
561+
a2_gscale=layer.w2_input_scale_quant,
562+
w2_fp4=layer.w2_weight,
563+
w2_blockscale=layer.w2_weight_scale,
564+
w2_alphas=layer.g2_alphas,
565+
topk_weights=topk_weights,
566+
topk_ids=topk_ids,
567+
params=layer.cutlass_moe_params,
568+
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
569+
).to(x.dtype)
546570

547-
return trtllm_fp4_block_scale_moe(
548-
routing_logits=router_logits,
549-
routing_bias=correction_bias,
550-
hidden_states=hs_fp4,
551-
hidden_states_scale=hs_scale,
552-
gemm1_weights=layer.gemm1_weights_fp4_shuffled,
553-
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
554-
torch.float8_e4m3fn
555-
),
556-
gemm1_bias=None,
557-
gemm1_alpha=None,
558-
gemm1_beta=None,
559-
gemm1_clamp_limit=None,
560-
gemm2_weights=layer.gemm2_weights_fp4_shuffled,
561-
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
562-
torch.float8_e4m3fn
563-
),
564-
gemm2_bias=None,
565-
output1_scale_scalar=layer.g1_scale_c,
566-
output1_scale_gate_scalar=layer.g1_alphas,
567-
output2_scale_scalar=layer.g2_alphas,
568-
num_experts=layer.num_experts,
569-
top_k=topk_config.top_k,
570-
n_group=topk_config.num_expert_group,
571-
topk_group=topk_config.topk_group,
572-
intermediate_size=layer.intermediate_size_per_partition,
573-
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
574-
local_num_experts=layer.num_local_experts,
575-
routed_scaling_factor=routed_scaling_factor,
576-
routing_method_type=layer.routing_method_type,
577-
do_finalize=True,
578-
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
579-
output=symm_output,
580-
)[0]
571+
return StandardCombineInput(hidden_states=output)
581572

582573

583574
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

test/registered/8-gpu-models/test_mistral_large3.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,21 @@
99

1010
# Runs on both H200 and B200 via nightly-8-gpu-common suite
1111
# Note: trtllm_mla backend may have hardware-specific behavior
12-
register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True)
12+
register_cuda_ci(est_time=3000, suite="nightly-8-gpu-common", nightly=True)
1313

14-
MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
14+
MISTRAL_LARGE3_FP8_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
15+
MISTRAL_LARGE3_NVFP4_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
1516
MISTRAL_LARGE3_EAGLE_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle"
1617

1718

1819
@unittest.skipIf(not is_blackwell_system(), "Requires B200")
1920
class TestMistralLarge3(unittest.TestCase):
2021
"""Unified test class for Mistral-Large-3 performance and accuracy.
2122
22-
Two variants:
23-
- basic: TP=8 + trtllm_mla backend
23+
Three variants:
24+
- basic: FP8 model + TP=8 + trtllm_mla backend
2425
- eagle: basic + EAGLE speculative decoding with draft model
26+
- nvfp4: NVFP4 model + TP=8 + trtllm_mla backend
2527
2628
Each variant runs BOTH:
2729
- Performance test (using NightlyBenchmarkRunner)
@@ -56,22 +58,33 @@ def test_mistral_large3_all_variants(self):
5658
"--speculative-num-draft-tokens=4",
5759
"--kv-cache-dtype=auto",
5860
]
61+
# TODO: add this to base args when FP8 TRTLLM moe is supported
62+
nvfp4_args = [
63+
"--moe-runner-backend=flashinfer_trtllm",
64+
]
5965

6066
variants = [
61-
# Variant: "basic" - TP=8 + trtllm_mla backend
67+
# Variant: "basic" - FP8 model + TP=8 + trtllm_mla backend
6268
ModelLaunchSettings(
63-
MISTRAL_LARGE3_MODEL_PATH,
69+
MISTRAL_LARGE3_FP8_MODEL_PATH,
6470
tp_size=8,
6571
extra_args=base_args,
6672
variant="TP8",
6773
),
68-
# Variant: "eagle" - TP=8 + trtllm_mla + EAGLE with draft model
74+
# Variant: "eagle" - FP8 model + TP=8 + trtllm_mla + EAGLE with draft model
6975
ModelLaunchSettings(
70-
MISTRAL_LARGE3_MODEL_PATH,
76+
MISTRAL_LARGE3_FP8_MODEL_PATH,
7177
tp_size=8,
7278
extra_args=base_args + eagle_args,
7379
variant="TP8+MTP",
7480
),
81+
# Variant: "nvfp4" - NVFP4 model + TP=8 + trtllm_mla backend
82+
ModelLaunchSettings(
83+
MISTRAL_LARGE3_NVFP4_MODEL_PATH,
84+
tp_size=8,
85+
extra_args=base_args + nvfp4_args,
86+
variant="NVFP4",
87+
),
7588
]
7689

7790
run_combined_tests(

0 commit comments

Comments
 (0)