Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ transforms:
stage: post_load_fusion
enabled: true
backend: trtllm
fuse_nvfp4_moe:
stage: post_load_fusion
enabled: false
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
Expand Down
152 changes: 112 additions & 40 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@
# limitations under the License.


import math

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
TRTLLM_NVFP4_COLUMN_SIZE,
TRTLLM_NVFP4_ROW_SIZE,
TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
)
from tensorrt_llm._torch.utils import ActivationType


Expand Down Expand Up @@ -212,17 +219,17 @@ def trtllm_quant_fp8_moe_fused_fake(

@torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=())
def trtllm_quant_nvfp4_moe_fused(
x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
fc1_expert_weights_fp4: torch.Tensor, # [E, 2*I, H] or [E, I, H]; uint8
fc2_expert_weights_fp4: torch.Tensor, # [E, H, I]; uint8
fc1_weight_blockscale_fp8: torch.Tensor, # Global scale for fc1 (scalar)
fc2_weight_blockscale_fp8: torch.Tensor, # Global scale for w2 (scalar)
fc1_act_global_scale: torch.Tensor, # Global scale for FC1 activations
fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations
fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
fc1_expert_weights_fp4: torch.Tensor,
fc2_expert_weights_fp4: torch.Tensor,
fc1_weight_blockscale_fp8: torch.Tensor,
fc2_weight_blockscale_fp8: torch.Tensor,
fc1_act_global_scale: torch.Tensor,
fc2_act_global_scale: torch.Tensor,
fc1_alpha: torch.Tensor,
fc2_alpha: torch.Tensor,
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
Expand All @@ -234,28 +241,100 @@ def trtllm_quant_nvfp4_moe_fused(
For mlp:
y = act(x @ w1.T) @ w2.T # act := ReLU^2

Notes:
- FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
- FC2 implements: fc2_output = fc1_output @ w2.T
- FC1 weights are concatenated w3 and w1 if gated_mlp, otherwise w1
    - Pairs of FP4 elements are packed into a single uint8 element

FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
FC2 implements: fc2_output = fc1_output @ w2.T

Parameters:
x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
selected_experts: Expert indices (B*S, TOP_K)
routing_weights: Routing weights (B*S, TOP_K)
fc1_expert_weights_fp4: FP4 FC1 weights [E, 2*I, H/2] or [E, I, H/2]; packed uint8
fc2_expert_weights_fp4: FP4 FC2 weights [E, H, I/2]; packed uint8
fc1_weight_blockscale_fp8: Block scales for FC1 weights (w1 or cat(w3, w1))
fc2_weight_blockscale_fp8: Block scales for FC2 weights (w2)
fc1_act_global_scale: Global scale for FC1 activations (scalar)
fc2_act_global_scale: Global scale for FC2 activations (scalar)
fc1_alpha: FC1 dequant scales = 1.0 / (fc1_act_global_scale * fc1_weight_global_scale)
fc2_alpha: FC2 dequant scales = 1.0 / (fc2_act_global_scale * fc2_weight_global_scale)
        is_gated_mlp: True for a gated MLP (SwiGLU-style, fc1 = cat(w3, w1)), False for a plain MLP
        act_fn: ActivationType value; Silu/Swiglu for gated_mlp, Relu2 for mlp
"""
NVFP4_BLOCK_SIZE = 16
NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
FP4_PER_UINT8 = 2

activation_type = ActivationType.Swiglu
if is_gated_mlp:
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
_, fc1_inter_size, _ = fc1_expert_weights_fp4.shape
n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape

# Convert the inter_size from number of uint8 elements to number of FP4 elements.
inter_size *= FP4_PER_UINT8

# Validate shapes and padding requirements as defined by the cutlass kernel.
assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D"
assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D"
assert fc1_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
assert fc2_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
assert fc1_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
assert fc2_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0

_validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn

if x.dtype in (torch.float16, torch.bfloat16):
x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
)
output_dtype = x.dtype
else:
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)
x_q_fp4 = x
input_blockscale = None
output_dtype = x.dtype

# Pad inter_size to be divisible by 128
inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
fc1_inter_size_padded = (
math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
)
hidden_size_padded = (
math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE
)

inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or (
not is_gated_mlp and inter_size_padded != inter_size
)
hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0
if inter_size_needs_padding or hidden_size_needs_padding:
assert False, "See https://github.com/NVIDIA/TensorRT-LLM/issues/10331"
# fc1_expert_weights_fp4: [E, I, H] or [E, 2*I, H]
fc1_padded = fc1_expert_weights_fp4.new_zeros(
fc1_expert_weights_fp4.size(0),
fc1_inter_size_padded,
hidden_size_padded // FP4_PER_UINT8,
)
fc1_padded[:, :fc1_inter_size, :] = fc1_expert_weights_fp4
fc1_expert_weights_fp4 = fc1_padded

# fc2_expert_weights_fp4: [E, H, I]
fc2_padded = fc2_expert_weights_fp4.new_zeros(
n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
)

assert inter_size % NVFP4_BLOCK_SIZE == 0, (
f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
)

fc2_padded[:, :, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
fc2_expert_weights_fp4 = fc2_padded

fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
n_experts, hidden_size_padded, inter_size_padded // NVFP4_BLOCK_SIZE
)
fc2_blockscale_fp8_padded[:, :, : inter_size // NVFP4_BLOCK_SIZE] = (
fc2_weight_blockscale_fp8
)
fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded

# quant_scales is described by this code:
# https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
Expand All @@ -270,26 +349,19 @@ def trtllm_quant_nvfp4_moe_fused(
fc2_alpha, # torch.float32; [E]
]

if x.dtype in (torch.float16, torch.bfloat16):
x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
)
output_dtype = x.dtype
else:
x_q_fp4 = x

trtllm_output = torch.ops.trtllm.fused_moe(
x_q_fp4,
selected_experts.to(torch.int),
routing_weights,
fc1_expert_weights=fc1_expert_weights_fp4,
x_q_fp4.view(torch.long),
selected_experts.to(torch.int32),
routing_weights.to(torch.float32),
# Groups of 16 FP4 weight elements are packed as a single int64 element (see isNvfp4Quant in moeOp.cpp)
fc1_expert_weights=fc1_expert_weights_fp4.view(torch.long),
fc1_expert_biases=None,
fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long),
fc2_expert_biases=None,
output_dtype=output_dtype,
quant_scales=quant_scales,
input_sf=input_blockscale,
activation_type=activation_type,
activation_type=act_fn,
)[0].view(x.shape)

return trtllm_output
Expand Down
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
TRTLLM_FP4_OP_AVAILABLE = True

TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16
TRTLLM_NVFP4_ROW_SIZE = 128
TRTLLM_NVFP4_COLUMN_SIZE = 4


@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())
Expand Down
Loading