Commit 6bbb43f

[None][feat] Add qwen3-next nvfp4 support (#8526)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
1 parent 7a552c4 commit 6bbb43f

File tree: 3 files changed (+57, -3 lines)


examples/models/core/qwen/README.md

Lines changed: 9 additions & 0 deletions

````diff
@@ -878,6 +878,15 @@ mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickst
 
 ```
 
+### NVFP4 quantization
+
+TRTLLM supports NVFP4 precision with blocksize=16 for both activations and GEMM weights.
+To run the Qwen3-Next model on NVFP4 precision, use the following command
+```bash
+mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --kv_cache_fraction 0.6 --disable_kv_cache_reuse --max_batch_size 1 --tp_size 2 --trust_remote_code
+
+```
+
 ## Notes and Troubleshooting
 
 - **Model Directory:** Update `<YOUR_MODEL_DIR>` with the actual path where the model weights reside.
````
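The new README section says NVFP4 uses a block size of 16, i.e. one scale factor per 16 contiguous values, for both activations and GEMM weights. As a rough, standalone illustration of what that per-block scaling means (this is not the TRT-LLM quantization kernel; the E2M1 value grid and the amax-based scale choice are assumptions made only for illustration):

```python
import torch

# Representable magnitudes of the FP4 E2M1 format (sign handled separately).
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quant_nvfp4_block(x: torch.Tensor, block_size: int = 16):
    """Per-block fake quantization: one scale per `block_size` contiguous values."""
    assert x.numel() % block_size == 0
    blocks = x.reshape(-1, block_size)
    # One scale per block, chosen so the block's largest magnitude maps to the top FP4 value (6.0).
    scales = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)
    scaled = blocks / scales
    # Snap each scaled value to the nearest entry of the FP4 grid, keeping the sign.
    idx = (scaled.abs().unsqueeze(-1) - FP4_GRID).abs().argmin(dim=-1)
    dequant = FP4_GRID[idx] * scaled.sign() * scales
    return dequant.reshape_as(x), scales


if __name__ == "__main__":
    w = torch.randn(4, 32)  # each row holds 2 blocks of 16
    w_dq, s = fake_quant_nvfp4_block(w.flatten())
    print("max abs error:", (w.flatten() - w_dq).abs().max().item())
```

In the real pipeline the per-block scale factors are stored in a compact interleaved (swizzled) layout that the GEMM kernel expects, which is exactly what the `linear.py` changes below manipulate.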

tensorrt_llm/_torch/modules/attention.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -413,9 +413,9 @@ def _attn_impl(
 
         out_scale = None
         out_scale_sf = None
-        if self.has_quant_scale:
+        if self.has_quant_scale and not self.attn_output_gate:
             out_scale = self.o_proj.inv_input_scale
-        if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
+        if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output and not self.attn_output_gate:
             out_scale_sf = self.o_proj.input_scale
 
         kv_scales_sf = None
```
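Both conditions gain `and not self.attn_output_gate`: when Qwen3-Next's attention output gate is in use, the attention kernel should not emit a pre-quantized (FP8/NVFP4) output for `o_proj`, since the gate still has to be applied to the attention output in the activation dtype first. A rough sketch of that intent, standalone and not the actual TRT-LLM control flow (the callable `o_proj`, the tensor names, and the sigmoid form of the gate are assumptions):

```python
from typing import Callable, Optional

import torch


def project_attn_output(attn_out: torch.Tensor,
                        gate: Optional[torch.Tensor],
                        o_proj: Callable[[torch.Tensor], torch.Tensor],
                        supports_quantized_output: bool) -> torch.Tensor:
    """Sketch: only hand o_proj a pre-quantized attention output when there is no gate."""
    if gate is None and supports_quantized_output:
        # Fast path: the attention kernel may write FP8/NVFP4 directly and skip
        # one quantization round trip in front of o_proj.
        return o_proj(attn_out)
    # Gated path: the elementwise multiply must happen on the high-precision
    # output, so the quantized-output fast path is disabled (the diff above).
    gated = attn_out * torch.sigmoid(gate) if gate is not None else attn_out
    return o_proj(gated)
```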

tensorrt_llm/_torch/modules/linear.py

Lines changed: 46 additions & 1 deletion

```diff
@@ -29,7 +29,7 @@
 from ...models.modeling_utils import QuantConfig
 from ..cublaslt_utils import IS_CUBLASLT_AVAILABLE
 from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
-from ..utils import Fp4QuantizedTensor
+from ..utils import Fp4QuantizedTensor, unswizzle_sf
 
 
 class WeightMode(str, enum.Enum):
@@ -824,6 +824,9 @@ def apply(self, module: Linear, input: torch.Tensor,
                 act_sf,
                 module.weight_scale,
                 module.alpha, module.dtype)
+        # Take the dim of out_features if padded.
+        if output.shape[-1] > module.out_features:
+            output = output[..., :module.out_features]
 
         if bias is not None:
             output = output + bias
```
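The three added lines pair with the weight padding introduced in `post_load_weights` (next hunk): once the weight's row dimension is padded up to the kernel's alignment, the NVFP4 GEMM output can be wider than `module.out_features`, and the extra zero columns are sliced away. A minimal standalone sketch of that pad, matmul, slice round trip, using an ordinary float matmul in place of the NVFP4 GEMM and a hypothetical layer shape:

```python
import torch
import torch.nn.functional as F

out_features, in_features = 1000, 512  # hypothetical layer shape
row_alignment = 32                     # alignment assumed on the weight's row (out_features) dim

x = torch.randn(4, in_features)
w = torch.randn(out_features, in_features)

# Pad out_features (the weight's row dim) up to a multiple of the alignment.
row_pad = (row_alignment - out_features % row_alignment) % row_alignment
w_padded = F.pad(w, (0, 0, 0, row_pad))  # F.pad pairs run last-dim first: (cols_l, cols_r, rows_l, rows_r)

y = x @ w_padded.t()          # the GEMM sees the padded weight, so the output is wider
y = y[..., :out_features]     # drop the zero columns, as apply() now does
assert y.shape[-1] == out_features
```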
```diff
@@ -957,6 +960,48 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         copy_weight(module.alpha, alpha)
         module.scalar_alpha = alpha.item()
 
+    def post_load_weights(self, module: Linear):
+        super().post_load_weights(module)
+        """
+        Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements.
+
+        Args:
+            row_alignment: Required row alignment (default: 32)
+            col_alignment: Required column alignment (default: 16)
+        """
+        row_alignment, col_alignment = 32, 16
+        row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment
+        col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment
+        if row_pad_size != 0 or col_pad_size != 0:
+            # Pad weight to meet NVFP4 GEMM kernel alignment requirements
+            module.weight = Parameter(F.pad(module.weight,
+                                            (0, col_pad_size, 0, row_pad_size),
+                                            mode='constant',
+                                            value=0),
+                                      requires_grad=False)
+            weight_col_size = module.weight.size(1)
+            assert (
+                weight_col_size * 2
+            ) % module.scaling_vector_size == 0, f"weight column size after padding {weight_col_size} must be divisible by scaling_vector_size {module.scaling_vector_size}"
+            # Pad weight_scale to match padded weight dimensions
+            # Padding should be performed on unswizzled weight_scale tensor
+            scale_rows = fp4_utils.pad_up(module.out_features, 128)
+            scale_cols = fp4_utils.pad_up(
+                module.in_features // module.scaling_vector_size, 4)
+            weight_scale_unswizzle = unswizzle_sf(module.weight_scale.data,
+                                                  scale_rows, scale_cols,
+                                                  module.scaling_vector_size)
+            weight_scale_unswizzle_pad = F.pad(
+                weight_scale_unswizzle,
+                (0, (col_pad_size * 2) // module.scaling_vector_size, 0,
+                 row_pad_size),
+                mode='constant',
+                value=0)
+            module.weight_scale = Parameter(
+                torch.ops.trtllm.block_scale_interleave(
+                    weight_scale_unswizzle_pad),
+                requires_grad=False)
+
 
 class W4A8NVFP4FP8LinearMethod(LinearMethodBase):
 
```
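The new `post_load_weights` rounds the stored weight up to 32 rows by 16 columns and pads the block-scale tensor to match; the scales are unswizzled first so the zero padding lands in the right positions, then re-interleaved with `torch.ops.trtllm.block_scale_interleave`. The pad-size arithmetic is easier to follow on a concrete shape. The sketch below only reproduces that arithmetic with plain tensors (it skips the swizzle/unswizzle step), and the layer shape and `scaling_vector_size = 16` are assumptions taken from the NVFP4 block size:

```python
import torch
import torch.nn.functional as F

# Hypothetical Qwen3-Next-style linear: the logical weight is out_features x in_features,
# but NVFP4 packs two 4-bit values per byte, so the stored tensor has in_features // 2 columns.
out_features, in_features = 1000, 512
scaling_vector_size = 16  # NVFP4 block size: elements covered by one scale factor

weight_packed = torch.zeros(out_features, in_features // 2, dtype=torch.uint8)
weight_scale = torch.zeros(out_features, in_features // scaling_vector_size)  # unswizzled layout

row_alignment, col_alignment = 32, 16  # alignment required by the NVFP4 GEMM
row_pad = (row_alignment - weight_packed.size(0)) % row_alignment  # (32 - 1000 % 32) % 32 = 24
col_pad = (col_alignment - weight_packed.size(1)) % col_alignment  # (16 - 256 % 16) % 16 = 0

# Pad the packed weight: F.pad pairs run last-dim first (columns), then rows.
weight_packed = F.pad(weight_packed, (0, col_pad, 0, row_pad))

# Each padded byte holds two FP4 elements, so the scale tensor gains
# (col_pad * 2) // scaling_vector_size extra scale columns per row.
weight_scale = F.pad(weight_scale, (0, (col_pad * 2) // scaling_vector_size, 0, row_pad))

print(weight_packed.shape)  # torch.Size([1024, 256])
print(weight_scale.shape)   # torch.Size([1024, 32])
```

After this padding, the scale tensor is handed back to `block_scale_interleave` in the real code so the GEMM kernel sees a consistently padded weight/scale pair, and `apply()` (previous hunk) trims the resulting output back to `out_features`.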
