Skip to content

Commit 172240d

Browse files
meenchencodego7250
authored and committed
[OMNIML-2932] [feat] nvfp4 awq support (NVIDIA#8698)
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent cf8f5a1 commit 172240d

File tree

8 files changed

+365
-3
lines changed

8 files changed

+365
-3
lines changed

tensorrt_llm/_torch/modules/attention.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,11 @@ def create_output(self, q: torch.Tensor):
382382
out_dtype = q.dtype
383383

384384
if self.attn_backend == "TRTLLM":
385-
if self.has_quant_scale and (self.attn.has_fp8_kv_cache
386-
or self.attn.has_fp4_kv_cache):
385+
# Don't use FP8 output if o_proj has pre_quant_scale - keep BF16 for better precision
386+
has_pre_quant_scale = getattr(self.o_proj, 'pre_quant_scale',
387+
None) is not None
388+
if self.has_quant_scale and not has_pre_quant_scale and (
389+
self.attn.has_fp8_kv_cache or self.attn.has_fp4_kv_cache):
387390
out_dtype = torch.float8_e4m3fn
388391
output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
389392
return output
@@ -414,8 +417,18 @@ def _attn_impl(
414417

415418
out_scale = None
416419
out_scale_sf = None
417-
if self.has_quant_scale and not self.attn_output_gate:
420+
has_awq_pre_quant_scale = hasattr(
421+
self.o_proj,
422+
'pre_quant_scale') and self.o_proj.pre_quant_scale is not None
423+
# Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output
424+
# and keeps attention output in BF16 for better precision when applying pre_quant_scale
425+
if self.has_quant_scale and not self.attn_output_gate and not has_awq_pre_quant_scale:
418426
out_scale = self.o_proj.inv_input_scale
427+
if has_awq_pre_quant_scale and enable_attn_nvfp4_output:
428+
logger.warning_once(
429+
"Disable attn nvfp4 output because o_proj has pre_quant_scale for AWQ.",
430+
key="disable_attn_nvfp4_output_for_awq")
431+
enable_attn_nvfp4_output = False
419432
if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output and not self.attn_output_gate:
420433
out_scale_sf = self.o_proj.input_scale
421434

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,14 @@ def forward_chunk(
437437
elif self.has_int8_woq_per_channel:
438438
use_int8_woq_per_channel = True
439439
elif self.has_nvfp4:
440+
# Apply pre_quant_scale if it exists (for NVFP4_AWQ)
441+
if hasattr(
442+
self,
443+
'fc31_act_scale') and self.fc31_act_scale is not None:
444+
assert not isinstance(
445+
x, Fp4QuantizedTensor
446+
), "Fp4QuantizedTensor is not expected for AWQ quantization."
447+
x = x * self.fc31_act_scale
440448
if run_post_quant_allgather or self.enable_alltoall:
441449
if isinstance(x, Fp4QuantizedTensor):
442450
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,13 @@ def quantize_input(self, x, post_quant_comm: bool = True):
316316
x_row = x.shape[0]
317317
x, x_sf = x.fp4_tensor, x.scaling_factor
318318
else:
319+
# Apply pre_quant_scale if it exists (for NVFP4_AWQ)
320+
# fc31_act_scale shape: (1, hidden_size)
321+
# x shape: (num_tokens, hidden_size)
322+
if hasattr(
323+
self,
324+
'fc31_act_scale') and self.fc31_act_scale is not None:
325+
x = x * self.fc31_act_scale
319326
x_row = x.shape[0]
320327
x, x_sf = torch.ops.trtllm.fp4_quantize(
321328
x, self.fc31_input_scale, self.scaling_vector_size, False,

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,6 +1716,10 @@ def create_weights(self,
17161716
requires_grad=False)
17171717
module.register_parameter("fc2_alpha", fc2_alpha)
17181718

1719+
# Optional per-channel act scale for NVFP4_AWQ (pre_quant_scale support)
1720+
# This will be initialized in load_quant_scales if pre_quant_scale exists
1721+
module.register_parameter("fc31_act_scale", None)
1722+
17191723
super().create_weights(module, weight_dtype, w3_w1_weight_shape,
17201724
w2_weight_shape)
17211725

@@ -1834,12 +1838,30 @@ def load_all_fp4_weight_scales_and_alphas(
18341838
dst_fc2_alpha[expert_idx])
18351839

18361840
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
1841+
# Check if pre_quant_scale exists in the checkpoint (for NVFP4_AWQ)
1842+
has_pre_quant_scale = False
1843+
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
1844+
# Check if any expert has pre_quant_scale
1845+
has_pre_quant_scale = f"0.w1.pre_quant_scale" in weights
1846+
18371847
# Step1: Load input scales.
18381848
tmp_fc31_input_scale = torch.empty(module.num_experts,
18391849
dtype=torch.float32)
18401850
tmp_fc2_input_scale = torch.empty(module.num_experts,
18411851
dtype=torch.float32)
18421852

1853+
# If pre_quant_scale exists, we need a per-channel act scale for fc31
1854+
# All experts share the same input, so pre_quant_scale should be identical across experts
1855+
if has_pre_quant_scale:
1856+
# Create fc31_act_scale parameter (for gate_up_proj / w3_w1)
1857+
# Shape: (1, hidden_size) - single vector for all experts (they share the same input)
1858+
fc31_act_scale = nn.Parameter(torch.empty(1,
1859+
module.hidden_size,
1860+
dtype=module.dtype,
1861+
device='cuda'),
1862+
requires_grad=False)
1863+
module.register_parameter("fc31_act_scale", fc31_act_scale)
1864+
18431865
for expert_id in range(module.num_experts):
18441866
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
18451867
w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
@@ -1866,6 +1888,66 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
18661888
module.fc2_input_scale.data.copy_(
18671889
tmp_fc2_input_scale.max().reciprocal())
18681890

1891+
# Load pre_quant_scale if it exists (for NVFP4_AWQ)
1892+
if has_pre_quant_scale:
1893+
from ..linear import TensorParallelMode, load_weight_shard
1894+
1895+
device = module.fc31_act_scale.device
1896+
# Load fc31 (w3/w1) pre_quant_scales
1897+
# All experts should have identical pre_quant_scale since they share the same input
1898+
all_w3_pre_quant_scales = []
1899+
all_w1_pre_quant_scales = []
1900+
for expert_id in module.initial_local_expert_ids:
1901+
w3_pre_quant_scale = load_weight_shard(
1902+
weights[f"{expert_id}.w3.pre_quant_scale"],
1903+
module.tp_size,
1904+
module.tp_rank,
1905+
TensorParallelMode.ROW,
1906+
device=device)
1907+
w1_pre_quant_scale = load_weight_shard(
1908+
weights[f"{expert_id}.w1.pre_quant_scale"],
1909+
module.tp_size,
1910+
module.tp_rank,
1911+
TensorParallelMode.ROW,
1912+
device=device)
1913+
all_w3_pre_quant_scales.append(w3_pre_quant_scale)
1914+
all_w1_pre_quant_scales.append(w1_pre_quant_scale)
1915+
1916+
# Verify that all experts have identical pre_quant_scale
1917+
# (they should be the same since all experts share the same input)
1918+
w3_reference = all_w3_pre_quant_scales[0]
1919+
w1_reference = all_w1_pre_quant_scales[0]
1920+
1921+
def check_consistency(scale, ref_scale, scale_name, expert_id):
1922+
if not torch.allclose(scale, ref_scale, rtol=1e-5, atol=1e-8):
1923+
max_diff = (scale - ref_scale).abs().max()
1924+
msg = (
1925+
f"MoE pre_quant_scale: expert {expert_id} {scale_name} "
1926+
f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. "
1927+
f"All experts should have identical pre_quant_scale since they share the same input."
1928+
)
1929+
logger.error(msg)
1930+
raise ValueError(msg)
1931+
1932+
for i, (w3_scale, w1_scale) in enumerate(
1933+
zip(all_w3_pre_quant_scales[1:],
1934+
all_w1_pre_quant_scales[1:]), 1):
1935+
check_consistency(w3_scale, w3_reference, "w3.pre_quant_scale",
1936+
module.initial_local_expert_ids[i])
1937+
check_consistency(w1_scale, w1_reference, "w1.pre_quant_scale",
1938+
module.initial_local_expert_ids[i])
1939+
1940+
# Take the maximum pre_quant_scale between w3 and w1 from the first expert
1941+
# (all experts should have the same values)
1942+
# Shape: (hidden_size,)
1943+
# Keep on CUDA device (w3_reference and w1_reference are already on CUDA)
1944+
fc31_pre_quant_scale = torch.max(w3_reference, w1_reference).to(
1945+
dtype=module.dtype, device='cuda')
1946+
1947+
# Store as a single vector since all experts share the same pre_quant_scale
1948+
# This will be broadcasted to all tokens in the forward pass
1949+
module.fc31_act_scale.data.copy_(fc31_pre_quant_scale.unsqueeze(0))
1950+
18691951
# Step2: Load weight block scales and alphas.
18701952
self.load_all_fp4_weight_scales_and_alphas(
18711953
module, weights, module.initial_local_expert_ids,

tensorrt_llm/_torch/modules/linear.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,10 @@ def create_weights(self, module: Linear, in_features: int,
898898
module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32),
899899
requires_grad=False)
900900

901+
# NOTE: Not in all linear we have this tensor - pre_quant_scale is computed as an average and merged with the
902+
# LayerNorm for QKV and Gate/Up projection layers when possible. we can see the tensor only for o_proj and down_proj
903+
module.pre_quant_scale = None
904+
901905
if bias:
902906
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
903907
requires_grad=False)
@@ -907,10 +911,28 @@ def create_weights(self, module: Linear, in_features: int,
907911
def apply(self, module: Linear, input: torch.Tensor,
908912
bias: Optional[torch.Tensor]):
909913
if isinstance(input, Fp4QuantizedTensor):
914+
# Input is already quantized - this should not happen if pre_quant_scale exists
915+
# because we disable FP4 output for attention output when pre_quant_scale is present
916+
if module.pre_quant_scale is not None:
917+
raise RuntimeError(
918+
"Received FP4 quantized input but pre_quant_scale exists. "
919+
"This indicates FP4 output was not properly disabled for the previous layer."
920+
)
910921
act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
911922
elif isinstance(input, tuple):
923+
# Input is a tuple of (fp4_tensor, scaling_factor)
924+
if module.pre_quant_scale is not None:
925+
raise RuntimeError(
926+
"Received FP4 quantized tuple input but pre_quant_scale exists. "
927+
"This indicates FP4 output was not properly disabled for the previous layer."
928+
)
912929
act_fp4, act_sf = input
913930
else:
931+
# Input is a regular tensor () - apply pre_quant_scale if it exists (for NVFP4_AWQ)
932+
if module.pre_quant_scale is not None:
933+
assert input.dtype == module.pre_quant_scale.dtype, "Input dtype and pre_quant_scale dtype must match"
934+
input = input * module.pre_quant_scale
935+
914936
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
915937
input, module.input_scale, module.scaling_vector_size, False)
916938

@@ -1003,6 +1025,24 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
10031025
copy_weight(module.alpha, alpha)
10041026
module.scalar_alpha = alpha.item()
10051027

1028+
# Load pre_quant_scale if it exists (for NVFP4_AWQ)
1029+
if "pre_quant_scale" in weights[0]:
1030+
device = module.weight.device
1031+
pre_quant_scale = load_weight_shard(
1032+
weights[0]["pre_quant_scale"],
1033+
module.tp_size,
1034+
module.tp_rank,
1035+
# pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
1036+
TensorParallelMode.flip(module.tp_mode),
1037+
device,
1038+
)
1039+
1040+
module.pre_quant_scale = Parameter(
1041+
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
1042+
requires_grad=False).to(device=device)
1043+
1044+
copy_weight(module.pre_quant_scale, pre_quant_scale)
1045+
10061046
def load_weights_fused_qkv_linear(self, module: Linear,
10071047
weights: List[Dict]) -> None:
10081048
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
@@ -1059,6 +1099,25 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
10591099
copy_weight(module.alpha, alpha)
10601100
module.scalar_alpha = alpha.item()
10611101

1102+
# Load pre_quant_scale if it exists (for NVFP4_AWQ)
1103+
# NOTE: pre_quant_scale is the same for gate and up since modelopt checks which layer shared the same input
1104+
if "pre_quant_scale" in weights[0]:
1105+
device = module.weight.device
1106+
pre_quant_scale = load_weight_shard(
1107+
weights[0]["pre_quant_scale"],
1108+
module.tp_size,
1109+
module.tp_rank,
1110+
# pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
1111+
TensorParallelMode.flip(module.tp_mode),
1112+
device,
1113+
)
1114+
1115+
module.pre_quant_scale = Parameter(
1116+
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
1117+
requires_grad=False).to(device=device)
1118+
1119+
copy_weight(module.pre_quant_scale, pre_quant_scale)
1120+
10621121
def post_load_weights(self, module: Linear):
10631122
super().post_load_weights(module)
10641123
"""

tensorrt_llm/quantization/mode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
4444
W4A8_MXFP4_FP8 = auto()
4545
W4A8_MXFP4_MXFP8 = auto()
4646
W4A16_MXFP4 = auto()
47+
NVFP4_AWQ = auto()
4748
NO_QUANT = auto()
4849

4950

@@ -410,6 +411,9 @@ def from_quant_algo(
410411
quant_mode = QuantMode.from_description(use_fp8_block_scales=True)
411412
elif quant_algo == QuantAlgo.NVFP4:
412413
quant_mode = QuantMode.from_description(use_nvfp4=True)
414+
elif quant_algo == QuantAlgo.NVFP4_AWQ:
415+
# NVFP4_AWQ uses the same QuantMode as NVFP4, distinction is at QuantAlgo level
416+
quant_mode = QuantMode.from_description(use_nvfp4=True)
413417
elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8:
414418
quant_mode = QuantMode.from_description(use_w4a8_nvfp4_fp8=True)
415419
elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8:

tensorrt_llm/quantization/quantize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ def quantize_layers(
7272
else:
7373
quant_mode = quant_config.quant_mode
7474
init_params["quant_mode"] = quant_mode
75+
76+
# Auto-detect pre_quant_scale based on quant_algo
77+
# For AWQ-based quantization methods that use pre_quant_scale
78+
if quant_config.quant_algo in [
79+
QuantAlgo.W4A16_AWQ, QuantAlgo.NVFP4_AWQ,
80+
QuantAlgo.W4A8_AWQ
81+
]:
82+
init_params["pre_quant_scale"] = True
7583
if "bias" in init_params and not isinstance(module,
7684
MixtureOfExperts):
7785
init_params["bias"] = init_params["bias"] is not None

0 commit comments

Comments (0)