Description
trtllm_fp4_block_scale_moe produces significantly different results compared to flashinfer_cutlass_fused_moe
when running real model weights and activations from nvidia/Qwen3-Coder-480B-A35B-Instruct-NVFP4.
The discrepancy is not reproducible with randomly generated data — random inputs always produce
matching results between the two kernels. This strongly suggests a kernel-level numerical accuracy
bug triggered only by specific real-world input distributions.
The companion model nvidia/Qwen3-235B-A22B-Instruct-2507-NVFP4 is not affected: both random
and real data produce matching results between trtllm_fp4_block_scale_moe and cutlass_fused_moe.
Environment
- SGLang with --moe-runner-backend flashinfer_trtllm
- FlashInfer (version with trtllm_fp4_block_scale_moe support)
- GPU: NVIDIA (G)B200
- Model (affected): nvidia/Qwen3-Coder-480B-A35B-Instruct-NVFP4
- Model (unaffected): nvidia/Qwen3-235B-A22B-Instruct-2507-NVFP4
Reproducer
A standalone reproducer script is at
python/sglang/srt/layers/moe/fused_moe_triton/repro_fp4.py.
It loads a .pt dump of real kernel inputs captured during inference and runs both
trtllm_fp4_block_scale_moe and cutlass_fused_moe side-by-side, reporting the norm difference.
Step 1 — Capture a dump
Run SGLang with the patched layer.py (which dumps kernel inputs to /tmp/dbg_moe_repro.pt on the
first call where the trtllm/cutlass norm difference exceeds 10%):
```shell
python -m sglang.launch_server \
  --model-path nvidia/Qwen3-Coder-480B-A35B-Instruct-NVFP4 \
  --moe-runner-backend flashinfer_trtllm \
  <other args>
```

Then send a prompt; the dump will be written to /tmp/dbg_moe_repro.pt automatically.
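The dump hook in the patched layer.py can be approximated by the sketch below. This is a hypothetical reconstruction, not the actual SGLang code: the helper name `maybe_dump_repro`, its signature, and the `inputs` dict are all illustrative; only the 10% threshold and the dump-on-first-divergence behavior come from the description above.

```python
import os

import torch


def maybe_dump_repro(
    trtllm_out: torch.Tensor,
    cutlass_out: torch.Tensor,
    inputs: dict,
    path: str = "/tmp/dbg_moe_repro.pt",
    threshold: float = 0.10,
) -> bool:
    """Save kernel inputs once, the first time the two MoE kernel
    outputs disagree by more than `threshold` in relative norm."""
    diff = (trtllm_out.float() - cutlass_out.float()).norm()
    rel = (diff / cutlass_out.float().norm().clamp_min(1e-12)).item()
    if rel > threshold and not os.path.exists(path):
        torch.save(inputs, path)
        return True
    return False
```

Calling this after running both kernels on every MoE forward writes exactly one dump per server run, which keeps the overhead of the instrumentation bounded.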
Step 2 — Run the reproducer
```shell
# Real data from dump → large norm difference (bug visible)
python repro_fp4.py /tmp/dbg_moe_repro_affected.pt

# Random data with same shapes → no significant difference (bug hidden)
python repro_fp4.py /tmp/dbg_moe_repro_affected.pt --random

# Random hidden states, real weights/scales/router from dump → bisect hidden-state contribution
python repro_fp4.py /tmp/dbg_moe_repro_affected.pt --random-hidden

# Unaffected model: both modes should show no significant difference
python repro_fp4.py /tmp/dbg_moe_repro_unaffected.pt
python repro_fp4.py /tmp/dbg_moe_repro_unaffected.pt --random
```

Reproducer modes summary
| Flag | Hidden states | Weights / router / scales | Purpose |
|---|---|---|---|
| (none) | from dump | from dump | Full real-data comparison |
| --random | random | random | Confirm random data is fine |
| --random-hidden | random | from dump | Bisect hidden-state contribution |
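The summary statistics the reproducer prints can be reconstructed as follows. This is an assumed reimplementation for reference; the exact definitions in repro_fp4.py may differ slightly, and `compare_outputs` is not a function from the repo.

```python
import torch


def compare_outputs(trtllm_out: torch.Tensor, cutlass_out: torch.Tensor) -> dict:
    """Compute rel_norm_diff / max_diff / mean_diff between the two
    kernel outputs, in float32 to avoid bf16 accumulation error."""
    t = trtllm_out.float()
    c = cutlass_out.float()
    diff = t - c
    return {
        "rel_norm_diff": (diff.norm() / c.norm().clamp_min(1e-12)).item(),
        "max_diff": diff.abs().max().item(),
        "mean_diff": diff.abs().mean().item(),
    }
```

Under these definitions, a rel_norm_diff of 2.5 (as in the affected-model run below) means the disagreement between the kernels is larger than the cutlass output itself, i.e. the outputs are essentially unrelated rather than merely imprecise.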
Observed Output
Affected model (nvidia/Qwen3-Coder-480B-A35B-Instruct-NVFP4) — real data
```
trtllm: norm=2.419e+03 nan=0
tensor([[-0.7344, 0.2832, -0.3535, ..., -0.8594, 0.3477, 0.3887],
        [-0.7344, 0.2832, -0.3535, ..., -0.8594, 0.3477, 0.3887],
        [-0.7344, 0.2832, -0.3535, ..., -0.8594, 0.3477, 0.3887],
        ...,
        [-0.7305, 0.2773, -0.3633, ..., -0.8477, 0.3066, 0.4004],
        [-0.7305, 0.2773, -0.3633, ..., -0.8477, 0.3066, 0.4004],
        [-0.7305, 0.2773, -0.3633, ..., -0.8477, 0.3066, 0.4004]],
       device='cuda:0')
cutlass: norm=6.900e+02 nan=0
tensor([[ 0.0801, -0.0464, -0.1387, ..., -0.0227, -0.1289, 0.0483],
        [ 0.0801, -0.0464, -0.1387, ..., -0.0227, -0.1289, 0.0483],
        [ 0.0801, -0.0464, -0.1387, ..., -0.0227, -0.1289, 0.0483],
        ...,
        [ 0.0767, -0.0452, -0.1436, ..., -0.0176, -0.1309, 0.0520],
        [ 0.0767, -0.0452, -0.1436, ..., -0.0176, -0.1309, 0.0520],
        [ 0.0767, -0.0452, -0.1436, ..., -0.0176, -0.1309, 0.0520]],
       device='cuda:0')
trtllm vs cutlass: rel_norm_diff=2.506e+00 max_diff=1.875e+00 mean_diff=3.969e-0
```
Affected model — --random
```
trtllm: norm=7.091e+14 nan=0
tensor([[-2.0616e+11, -3.9460e+10,  2.2012e+10, ...,  1.9596e+10,
          2.7166e+11,  1.3583e+11],
        [ 1.7395e+11,  2.8991e+11, -1.2818e+10, ..., -1.5784e+11,
          4.4292e+10, -5.9861e+10],
        [ 1.6750e+11,  5.0332e+09, -7.5497e+09, ..., -2.8991e+11,
          1.0039e+11,  5.5298e+10],
        ...,
        [-1.5368e+10, -3.1998e+11, -1.8039e+11, ..., -1.7287e+11,
         -1.2992e+11, -1.6106e+11],
        [-1.2885e+10, -2.6441e+10,  8.3752e+10, ...,  3.7044e+10,
          4.3594e+11, -3.7849e+10],
        [-1.0670e+10,  5.8385e+09,  4.7245e+10, ..., -1.2616e+11,
         -7.7309e+10,  3.7581e+10]], device='cuda:0')
cutlass: norm=7.091e+14 nan=0
tensor([[-2.0508e+11, -3.9192e+10,  2.2280e+10, ...,  1.9327e+10,
          2.7166e+11,  1.3637e+11],
        [ 1.7502e+11,  2.8991e+11, -1.2952e+10, ..., -1.5891e+11,
          4.4560e+10, -6.0398e+10],
        [ 1.6858e+11,  4.9661e+09, -7.3820e+09, ..., -2.8991e+11,
          1.0093e+11,  5.6103e+10],
        ...,
        [-1.5099e+10, -3.1998e+11, -1.7931e+11, ..., -1.7287e+11,
         -1.2939e+11, -1.5999e+11],
        [-1.3086e+10, -2.6844e+10,  8.3215e+10, ...,  3.7044e+10,
          4.3594e+11, -3.8655e+10],
        [-9.8650e+09,  6.0062e+09,  4.6976e+10, ..., -1.2563e+11,
         -7.7309e+10,  3.7849e+10]], device='cuda:0')
trtllm vs cutlass: rel_norm_diff=7.259e-05 max_diff=4.295e+09 mean_diff=3.269e+08
```
Unaffected model (nvidia/Qwen3-235B-A22B-Instruct-2507-NVFP4) — real data
```
trtllm: norm=2.852e+02 nan=0
tensor([[-0.0435, 0.0116, -0.0164, ..., 0.0066, 0.1069, -0.0093],
        [-0.0435, 0.0116, -0.0164, ..., 0.0066, 0.1069, -0.0093],
        [-0.0435, 0.0116, -0.0164, ..., 0.0066, 0.1069, -0.0093],
        ...,
        [-0.0432, 0.0114, -0.0167, ..., 0.0061, 0.1069, -0.0099],
        [-0.0432, 0.0114, -0.0167, ..., 0.0061, 0.1069, -0.0099],
        [-0.0432, 0.0114, -0.0167, ..., 0.0061, 0.1069, -0.0099]],
       device='cuda:0')
cutlass: norm=2.849e+02 nan=0
tensor([[-0.0437, 0.0117, -0.0165, ..., 0.0066, 0.1074, -0.0093],
        [-0.0437, 0.0117, -0.0165, ..., 0.0066, 0.1074, -0.0093],
        [-0.0437, 0.0117, -0.0165, ..., 0.0066, 0.1074, -0.0093],
        ...,
        [-0.0435, 0.0109, -0.0168, ..., 0.0061, 0.1069, -0.0095],
        [-0.0435, 0.0109, -0.0168, ..., 0.0061, 0.1069, -0.0095],
        [-0.0435, 0.0109, -0.0168, ..., 0.0061, 0.1069, -0.0095]],
       device='cuda:0')
trtllm vs cutlass: rel_norm_diff=1.050e-03 max_diff=7.812e-03 mean_diff=3.126e-04
```
(Actual numbers to be filled in once .pt files are attached.)
.pt Dump Files
- Affected (nvidia/Qwen3-Coder-480B-A35B-Instruct-NVFP4): dbg_moe_repro_affected.pt (to be attached)
- Unaffected (nvidia/Qwen3-235B-A22B-Instruct-2507-NVFP4): dbg_moe_repro_unaffected.pt (to be attached)
Each .pt file contains:
```python
{
    "hidden_states_bf16": ...,    # [T, H] bfloat16 original hidden states before FP4 quantization
    "router_logits": ...,         # [T, E] float32
    "routing_bias": ...,          # [E] float32 or None
    "trtllm": {
        "hidden_states": ...,          # [T, H//2] uint8 FP4-packed
        "hidden_states_scale": ...,    # [T, H//16] fp8 linear block scales
        "gemm1_weights": ...,          # trtllm-shuffled w13 FP4 weights
        "gemm1_weights_scale": ...,    # trtllm-shuffled w13 block scales
        "gemm2_weights": ...,          # trtllm-shuffled w2 FP4 weights
        "gemm2_weights_scale": ...,    # trtllm-shuffled w2 block scales
        "output1_scale_gate_scalar": ...,  # g1_alphas [E]
        "output1_scale_scalar": ...,       # g1_scale_c [E]
        "output2_scale_scalar": ...,       # g2_alphas [E]
        "num_experts": int, "top_k": int, "n_group": int, "topk_group": int,
        "intermediate_size": int, "local_expert_offset": int, "local_num_experts": int,
        "routed_scaling_factor": float, "routing_method_type": int,
        "tune_max_num_tokens": int,
        "output": ...,                 # kernel result [T, H] bfloat16
    },
    "cutlass": {
        "input": ...,                  # [T, H//2] FP4-packed
        "input_sf": ...,               # [T, H//16] fp8 swizzled block scales
        "fc1_expert_weights": ...,     # w13 FP4 weights (cutlass layout)
        "fc2_expert_weights": ...,     # w2 FP4 weights (cutlass layout)
        "token_selected_experts": ..., # [T, top_k] int32
        "token_final_scales": ...,     # [T, top_k] float32
        "w13_input_scale_quant": ...,  # scalar float32
        "w13_blockscale_swizzled": ..., # swizzled w13 block scales
        "w2_input_scale_quant": ...,   # scalar float32
        "w2_blockscale_swizzled": ..., # swizzled w2 block scales
        "g1_alphas": ...,              # [E]
        "g2_alphas": ...,              # [E]
        "ep_size": int, "ep_rank": int, "tp_size": int, "tp_rank": int,
        "tune_max_num_tokens": int, "activation_type": int,
        "output": ...,                 # kernel result [T, H] bfloat16
    },
}
```

#2714 to track.