13 changes: 10 additions & 3 deletions tensorrt_llm/_torch/modules/attention.py
@@ -358,8 +358,9 @@ def create_output(self, q: torch.Tensor):
out_dtype = q.dtype

if self.attn_backend == "TRTLLM":
if self.has_quant_scale and (self.attn.has_fp8_kv_cache
or self.attn.has_fp4_kv_cache):
# Don't use FP8 output if o_proj has pre_quant_scale - keep BF16 for better precision
if self.has_quant_scale and self.o_proj.pre_quant_scale is None and (
self.attn.has_fp8_kv_cache or self.attn.has_fp4_kv_cache):
out_dtype = torch.float8_e4m3fn
output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
return output
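The dtype decision above can be summarized as a small helper. This is an illustrative sketch, not code from the PR; `backend_is_trtllm`, `has_pre_quant_scale`, and `fp8_or_fp4_kv_cache` are stand-in booleans for the module state the diff checks.

import torch

def pick_attention_out_dtype(q_dtype: torch.dtype,
                             backend_is_trtllm: bool,
                             has_quant_scale: bool,
                             has_pre_quant_scale: bool,
                             fp8_or_fp4_kv_cache: bool) -> torch.dtype:
    # FP8 output only pays off when o_proj can consume it directly; with an AWQ
    # pre_quant_scale the per-channel rescale must happen in BF16 first.
    if (backend_is_trtllm and has_quant_scale and not has_pre_quant_scale
            and fp8_or_fp4_kv_cache):
        return torch.float8_e4m3fn
    return q_dtype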
@@ -390,8 +391,14 @@ def _attn_impl(

out_scale = None
out_scale_sf = None
if self.has_quant_scale:
# Don't set out_scale if o_proj has pre_quant_scale - this prevents FP8/FP4 output
# and keeps attention output in BF16 for better precision when applying pre_quant_scale
if self.has_quant_scale and self.o_proj.pre_quant_scale is None:
out_scale = self.o_proj.inv_input_scale
if hasattr(
self.o_proj,
'pre_quant_scale') and self.o_proj.pre_quant_scale is not None:
enable_attn_nvfp4_output = False
if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
out_scale_sf = self.o_proj.input_scale

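Correspondingly, _attn_impl now leaves out_scale and out_scale_sf unset whenever o_proj carries a pre_quant_scale. A minimal sketch of that selection, with the o_proj attributes passed in as hypothetical arguments rather than read from the module:

from typing import Optional
import torch

def pick_out_scales(has_quant_scale: bool,
                    pre_quant_scale: Optional[torch.Tensor],
                    inv_input_scale: Optional[torch.Tensor],
                    input_scale: Optional[torch.Tensor],
                    o_proj_has_nvfp4: bool,
                    support_nvfp4_output: bool):
    out_scale = None
    out_scale_sf = None
    if has_quant_scale and pre_quant_scale is None:
        out_scale = inv_input_scale                      # FP8/FP4 attention output allowed
    enable_attn_nvfp4_output = pre_quant_scale is None   # AWQ scale forces BF16 output
    if o_proj_has_nvfp4 and support_nvfp4_output and enable_attn_nvfp4_output:
        out_scale_sf = input_scale
    return out_scale, out_scale_sf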
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -290,6 +290,14 @@ def forward_chunk(
elif self.has_int8_woq_per_channel:
use_int8_woq_per_channel = True
elif self.has_nvfp4:
# Apply pre_quant_scale if it exists (for NVFP4_AWQ)
if hasattr(
self,
'fc31_act_scale') and self.fc31_act_scale is not None:
assert not isinstance(
x, Fp4QuantizedTensor
), "Fp4QuantizedTensor is not expected for AWQ quantization."
x = x * self.fc31_act_scale
if run_post_quant_allgather or self.enable_alltoall:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
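A minimal sketch of the guard added to the Cutlass path, with hypothetical tensors: the activation must still be a plain BF16 tensor when the AWQ scale is applied, since rescaling after FP4 packing would be meaningless.

import torch

def prescale_for_nvfp4_awq(x: torch.Tensor, fc31_act_scale, is_already_fp4: bool) -> torch.Tensor:
    if fc31_act_scale is None:          # plain NVFP4 checkpoints register no activation scale
        return x
    assert not is_already_fp4, "Fp4QuantizedTensor is not expected for AWQ quantization."
    # (num_tokens, hidden_size) * (1, hidden_size): one shared scale broadcast over all tokens.
    return x * fc31_act_scale

x = torch.randn(4, 16, dtype=torch.bfloat16)
scale = torch.rand(1, 16, dtype=torch.bfloat16)
assert prescale_for_nvfp4_awq(x, scale, is_already_fp4=False).shape == x.shape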
12 changes: 11 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -187,7 +187,6 @@ def forward_impl(
) -> torch.Tensor:

assert x.dtype == torch.bfloat16

# DeepSeekV3 style routing
if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod):
top_k = self.routing_method.routing_impl.top_k
@@ -232,6 +231,14 @@ def forward_impl(
else:
x_row = x.shape[0]
x_col = x.shape[1]

# Apply pre_quant_scale if it exists (for NVFP4_AWQ)
# fc31_act_scale shape: (1, hidden_size)
# x shape: (num_tokens, hidden_size)
if hasattr(self, 'fc31_act_scale'
) and self.fc31_act_scale is not None:
x = x * self.fc31_act_scale

x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size,
False, False)
@@ -297,6 +304,9 @@ def forward_impl(
scale_factor_use_ue8m0 = False
is_scale_factor_swizzled = False # use linear layout here

if hasattr(self,
'fc31_act_scale') and self.fc31_act_scale is not None:
x = x * self.fc31_act_scale
if not run_post_quant_allgather:
hidden_states_fp4, hidden_states_scale_linear_fp4 = (
torch.ops.trtllm.fp4_quantize(
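Both trtllm-gen call sites enforce the same ordering: rescale in BF16 first, then quantize. A sketch with a placeholder for torch.ops.trtllm.fp4_quantize (the real op returns the packed FP4 tensor plus its scale factors):

import torch

def fake_fp4_quantize(x: torch.Tensor):
    # Stand-in for torch.ops.trtllm.fp4_quantize; only the call ordering matters here.
    return x, torch.ones(x.shape[0], dtype=torch.float32)

def quantize_moe_activations(x: torch.Tensor, fc31_act_scale):
    if fc31_act_scale is not None:      # NVFP4_AWQ: fold the activation scale in first
        x = x * fc31_act_scale          # (num_tokens, hidden_size) * (1, hidden_size)
    return fake_fp4_quantize(x)         # the FP4 kernel then sees pre-scaled BF16 activations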
86 changes: 86 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -1576,6 +1576,10 @@ def create_weights(self,
requires_grad=False)
module.register_parameter("fc2_alpha", fc2_alpha)

# Optional per-channel act scale for NVFP4_AWQ (pre_quant_scale support)
# This will be initialized in load_quant_scales if pre_quant_scale exists
module.register_parameter("fc31_act_scale", None)

super().create_weights(module, weight_dtype, w3_w1_weight_shape,
w2_weight_shape)
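Registering the parameter as None keeps fc31_act_scale addressable on every module while only materializing it when the checkpoint actually carries pre_quant_scale. A toy module showing this register-now, materialize-later pattern (the module itself is made up; the parameter name mirrors the diff):

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # `fc31_act_scale is None` stays a cheap check for non-AWQ checkpoints.
        self.register_parameter("fc31_act_scale", None)

    def attach_awq_scale(self, scale: torch.Tensor) -> None:
        # Materialized only when the checkpoint provides pre_quant_scale.
        self.register_parameter(
            "fc31_act_scale",
            nn.Parameter(scale.view(1, self.hidden_size), requires_grad=False))

m = ToyMoE(8)
assert m.fc31_act_scale is None
m.attach_awq_scale(torch.rand(8))
assert m.fc31_act_scale.shape == (1, 8)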

@@ -1678,12 +1682,32 @@ def load_all_fp4_weight_scales_and_alphas(
dst_fc2_alpha[expert_idx])

def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
# Check if pre_quant_scale exists in the checkpoint (for NVFP4_AWQ)
has_pre_quant_scale = False
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
# Check if any expert has pre_quant_scale
has_pre_quant_scale = "0.w1.pre_quant_scale" in weights

# Step1: Load input scales.
tmp_fc31_input_scale = torch.empty(module.num_experts,
dtype=torch.float32)
tmp_fc2_input_scale = torch.empty(module.num_experts,
dtype=torch.float32)

# If pre_quant_scale exists, we need a per-channel act scale for fc31
# All experts share the same input, so pre_quant_scale should be identical across experts
if has_pre_quant_scale:
from ..linear import TensorParallelMode, load_weight_shard

# Create fc31_act_scale parameter (for gate_up_proj / w3_w1)
# Shape: (1, hidden_size) - single vector for all experts (they share the same input)
fc31_act_scale = nn.Parameter(torch.empty(1,
module.hidden_size,
dtype=module.dtype,
device='cuda'),
requires_grad=False)
module.register_parameter("fc31_act_scale", fc31_act_scale)

for expert_id in range(module.num_experts):
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
@@ -1710,6 +1734,68 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
module.fc2_input_scale.data.copy_(
tmp_fc2_input_scale.max().reciprocal())

# Load pre_quant_scale if it exists (for NVFP4_AWQ)
if has_pre_quant_scale:
from ..linear import TensorParallelMode, load_weight_shard

# Load fc31 (w3/w1) pre_quant_scales
# All experts should have identical pre_quant_scale since they share the same input
all_w3_pre_quant_scales = []
all_w1_pre_quant_scales = []
for expert_id in module.initial_local_expert_ids:
w3_pre_quant_scale = load_weight_shard(
weights[f"{expert_id}.w3.pre_quant_scale"],
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device='cuda')
w1_pre_quant_scale = load_weight_shard(
weights[f"{expert_id}.w1.pre_quant_scale"],
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device='cuda')
all_w3_pre_quant_scales.append(w3_pre_quant_scale)
all_w1_pre_quant_scales.append(w1_pre_quant_scale)

# Verify that all experts have identical pre_quant_scale
# (they should be the same since all experts share the same input)
w3_reference = all_w3_pre_quant_scales[0]
w1_reference = all_w1_pre_quant_scales[0]

for i, (w3_scale, w1_scale) in enumerate(
zip(all_w3_pre_quant_scales[1:],
all_w1_pre_quant_scales[1:]), 1):
if not torch.allclose(
w3_scale, w3_reference, rtol=1e-5, atol=1e-8):
max_diff = (w3_scale - w3_reference).abs().max()
logger.warning(
f"MoE pre_quant_scale: expert {module.initial_local_expert_ids[i]} w3.pre_quant_scale "
f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. "
f"All experts should have identical pre_quant_scale since they share the same input. "
f"Using the first expert's value.")
break
if not torch.allclose(
w1_scale, w1_reference, rtol=1e-5, atol=1e-8):
max_diff = (w1_scale - w1_reference).abs().max()
logger.warning(
f"MoE pre_quant_scale: expert {module.initial_local_expert_ids[i]} w1.pre_quant_scale "
f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. "
f"All experts should have identical pre_quant_scale since they share the same input. "
f"Using the first expert's value.")
break

# Take the maximum pre_quant_scale between w3 and w1 from the first expert
# (all experts should have the same values)
# Shape: (hidden_size,)
# Keep on CUDA device (w3_reference and w1_reference are already on CUDA)
fc31_pre_quant_scale = torch.max(w3_reference, w1_reference).to(
dtype=module.dtype, device='cuda')

# Store as a single vector since all experts share the same pre_quant_scale
# This will be broadcasted to all tokens in the forward pass
module.fc31_act_scale.data.copy_(fc31_pre_quant_scale.unsqueeze(0))

# Step2: Load weight block scales and alphas.
self.load_all_fp4_weight_scales_and_alphas(
module, weights, module.initial_local_expert_ids,
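A toy reproduction of the scale-merging logic above, under the stated assumption that all experts share one pre_quant_scale: the per-expert w3/w1 vectors are checked for consistency and then folded with an element-wise max into a single (1, hidden_size) activation scale.

import torch

hidden_size, num_experts = 8, 4
base_w3, base_w1 = torch.rand(hidden_size), torch.rand(hidden_size)
w3 = [base_w3.clone() for _ in range(num_experts)]   # what a consistent checkpoint looks like
w1 = [base_w1.clone() for _ in range(num_experts)]

assert all(torch.allclose(s, w3[0]) for s in w3[1:])  # the diff only warns on a mismatch
assert all(torch.allclose(s, w1[0]) for s in w1[1:])

# w3 and w1 consume the same input, so one per-channel scale must cover both projections.
fc31_act_scale = torch.max(w3[0], w1[0]).unsqueeze(0)  # (1, hidden_size), broadcast over tokens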
59 changes: 59 additions & 0 deletions tensorrt_llm/_torch/modules/linear.py
@@ -774,6 +774,10 @@ def create_weights(self, module: Linear, in_features: int,
module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32),
requires_grad=False)

# NOTE: Not every Linear has this tensor - pre_quant_scale is averaged and fused into the preceding
# LayerNorm for the QKV and Gate/Up projection layers when possible, so it only appears for o_proj and down_proj
module.pre_quant_scale = None

if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
@@ -783,10 +787,28 @@ def apply(self, module: Linear, input: torch.Tensor,
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
if isinstance(input, Fp4QuantizedTensor):
# Input is already quantized - this should not happen if pre_quant_scale exists
# because we disable FP4 output for attention output when pre_quant_scale is present
if module.pre_quant_scale is not None:
raise RuntimeError(
"Received FP4 quantized input but pre_quant_scale exists. "
"This indicates FP4 output was not properly disabled for the previous layer."
)
act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
elif isinstance(input, tuple):
# Input is a tuple of (fp4_tensor, scaling_factor)
if module.pre_quant_scale is not None:
raise RuntimeError(
"Received FP4 quantized tuple input but pre_quant_scale exists. "
"This indicates FP4 output was not properly disabled for the previous layer."
)
act_fp4, act_sf = input
else:
# Input is a regular tensor - apply pre_quant_scale if it exists (for NVFP4_AWQ)
if module.pre_quant_scale is not None:
assert input.dtype == module.pre_quant_scale.dtype, "Input dtype and pre_quant_scale dtype must match"
input = input * module.pre_quant_scale

act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, module.input_scale, module.scaling_vector_size, False)
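A condensed sketch of the branch handling above, with Fp4QuantizedTensor and the (fp4, scale-factor) tuple collapsed into a single `already_quantized` flag and a placeholder quantizer standing in for torch.ops.trtllm.fp4_quantize:

from typing import Optional
import torch

def prepare_nvfp4_awq_input(inp, pre_quant_scale: Optional[torch.Tensor],
                            already_quantized: bool, quantize_fp4):
    if already_quantized:
        if pre_quant_scale is not None:
            raise RuntimeError(
                "Received FP4-quantized input but pre_quant_scale exists; the producing "
                "layer should have kept its output in BF16.")
        return inp
    if pre_quant_scale is not None:
        assert inp.dtype == pre_quant_scale.dtype, "dtypes must match"
        inp = inp * pre_quant_scale      # AWQ rescale happens before quantization
    return quantize_fp4(inp)             # placeholder for the real FP4 quantization op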

@@ -874,6 +896,24 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
copy_weight(module.alpha, alpha)
module.scalar_alpha = alpha.item()

# Load pre_quant_scale if it exists (for NVFP4_AWQ)
if "pre_quant_scale" in weights[0]:
device = module.weight.device
pre_quant_scale = load_weight_shard(
weights[0]["pre_quant_scale"],
module.tp_size,
module.tp_rank,
# pre_quant_scale applies to the activation rather than the weight, so the TP mode is flipped
TensorParallelMode.flip(module.tp_mode),
device,
)

module.pre_quant_scale = Parameter(
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=device)

copy_weight(module.pre_quant_scale, pre_quant_scale)

def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
@@ -930,6 +970,25 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
copy_weight(module.alpha, alpha)
module.scalar_alpha = alpha.item()

# Load pre_quant_scale if it exists (for NVFP4_AWQ)
# NOTE: pre_quant_scale is the same for gate and up since ModelOpt assigns a shared scale to layers that consume the same input
if "pre_quant_scale" in weights[0]:
device = module.weight.device
pre_quant_scale = load_weight_shard(
weights[0]["pre_quant_scale"],
module.tp_size,
module.tp_rank,
# pre_quant_scale applies to the activation rather than the weight, so the TP mode is flipped
TensorParallelMode.flip(module.tp_mode),
device,
)

module.pre_quant_scale = Parameter(
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=device)

copy_weight(module.pre_quant_scale, pre_quant_scale)


class W4A8NVFP4FP8LinearMethod(LinearMethodBase):

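The flipped TP mode in both loaders comes down to this: pre_quant_scale has length in_features and multiplies the activation, so it is only sharded when the weight's input dimension is sharded. A simplified stand-in for load_weight_shard illustrating the rule:

import torch

def shard_activation_scale(pre_quant_scale: torch.Tensor, tp_size: int, tp_rank: int,
                           weight_splits_in_features: bool) -> torch.Tensor:
    if not weight_splits_in_features:
        # Column-parallel weight (input replicated): every rank keeps the full vector.
        return pre_quant_scale
    # Row-parallel weight (e.g. o_proj / down_proj): shard along in_features like the activation.
    return pre_quant_scale.chunk(tp_size, dim=0)[tp_rank]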
4 changes: 4 additions & 0 deletions tensorrt_llm/quantization/mode.py
@@ -44,6 +44,7 @@ class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
W4A8_MXFP4_FP8 = auto()
W4A8_MXFP4_MXFP8 = auto()
W4A16_MXFP4 = auto()
NVFP4_AWQ = auto()
NO_QUANT = auto()


@@ -410,6 +411,9 @@ def from_quant_algo(
quant_mode = QuantMode.from_description(use_fp8_block_scales=True)
elif quant_algo == QuantAlgo.NVFP4:
quant_mode = QuantMode.from_description(use_nvfp4=True)
elif quant_algo == QuantAlgo.NVFP4_AWQ:
# NVFP4_AWQ uses the same QuantMode as NVFP4; the distinction is made at the QuantAlgo level
quant_mode = QuantMode.from_description(use_nvfp4=True)
elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8:
quant_mode = QuantMode.from_description(use_w4a8_nvfp4_fp8=True)
elif quant_algo == QuantAlgo.W4A8_MXFP4_FP8:
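A hedged check of what the new branch implies, assuming from_quant_algo is reachable as the QuantMode helper shown in the hunk: NVFP4 and NVFP4_AWQ resolve to identical QuantMode bits, and the AWQ-specific handling is keyed off QuantAlgo elsewhere.

from tensorrt_llm.quantization.mode import QuantAlgo, QuantMode

assert QuantMode.from_quant_algo(QuantAlgo.NVFP4) == \
       QuantMode.from_quant_algo(QuantAlgo.NVFP4_AWQ)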
8 changes: 8 additions & 0 deletions tensorrt_llm/quantization/quantize.py
@@ -72,6 +72,14 @@ def quantize_layers(
else:
quant_mode = quant_config.quant_mode
init_params["quant_mode"] = quant_mode

# Auto-detect pre_quant_scale based on quant_algo
# For AWQ-based quantization methods that use pre_quant_scale
if quant_config.quant_algo in [
QuantAlgo.W4A16_AWQ, QuantAlgo.NVFP4_AWQ,
QuantAlgo.W4A8_AWQ
]:
init_params["pre_quant_scale"] = True
if "bias" in init_params and not isinstance(module,
MixtureOfExperts):
init_params["bias"] = init_params["bias"] is not None
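The auto-detection rule added above, restated as a standalone predicate (the helper name is made up; the algorithm list mirrors the diff): AWQ-family algorithms ship a per-channel activation scale, so their layers are created with pre_quant_scale enabled.

from tensorrt_llm.quantization.mode import QuantAlgo

AWQ_ALGOS = (QuantAlgo.W4A16_AWQ, QuantAlgo.NVFP4_AWQ, QuantAlgo.W4A8_AWQ)

def wants_pre_quant_scale(quant_algo) -> bool:
    return quant_algo in AWQ_ALGOS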