
Commit 966231d

[#9626][feat] Add an auto-deploy transform for using cutlass FP4 MoE kernels (#10304)
Add a transform to replace torch.ops.auto_deploy.torch_quant_nvfp4_moe with the optimized torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused. The fused op currently generates wrong results when the number of rows in the MoE FC1 weights is not divisible by 128, so torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused is not set as the default FP4 MoE implementation (i.e. the transform is disabled). Signed-off-by: Neta Zmora <[email protected]>
1 parent 965578c commit 966231d
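
For context, the limitation called out in the commit message is an alignment requirement on the FC1 weight rows. Below is a minimal sketch of that check, assuming the packed-uint8 weight layout documented in the op further down; the helper name is made up for illustration and is not part of this commit.

import torch

TRTLLM_NVFP4_ROW_SIZE = 128  # row alignment required by the cutlass NVFP4 MoE kernel


def fc1_rows_are_aligned(fc1_expert_weights_fp4: torch.Tensor) -> bool:
    # fc1_expert_weights_fp4: [E, 2*I, H/2] for gated MLP or [E, I, H/2] otherwise; packed uint8.
    # When this returns False, trtllm_quant_nvfp4_moe_fused currently produces wrong results,
    # which is why the fuse_nvfp4_moe transform ships disabled by default.
    _, fc1_rows, _ = fc1_expert_weights_fp4.shape
    return fc1_rows % TRTLLM_NVFP4_ROW_SIZE == 0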

File tree

6 files changed: 519 additions & 218 deletions

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 3 additions & 0 deletions
@@ -126,6 +126,9 @@ transforms:
     stage: post_load_fusion
     enabled: true
     backend: trtllm
+  fuse_nvfp4_moe:
+    stage: post_load_fusion
+    enabled: false
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
     # TODO (lucaslie): add backend selection as part of configurable inference optimizers

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 112 additions & 40 deletions
@@ -14,8 +14,15 @@
 # limitations under the License.
 
 
+import math
+
 import torch
 
+from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
+    TRTLLM_NVFP4_COLUMN_SIZE,
+    TRTLLM_NVFP4_ROW_SIZE,
+    TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
+)
 from tensorrt_llm._torch.utils import ActivationType
 
 
@@ -212,17 +219,17 @@ def trtllm_quant_fp8_moe_fused_fake(
 
 @torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=())
 def trtllm_quant_nvfp4_moe_fused(
-    x: torch.Tensor,  # [B, S, H] or [B*S, H], 16-bit float
+    x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
-    fc1_expert_weights_fp4: torch.Tensor,  # [E, 2*I, H] or [E, I, H]; uint8
-    fc2_expert_weights_fp4: torch.Tensor,  # [E, H, I]; uint8
-    fc1_weight_blockscale_fp8: torch.Tensor,  # Global scale for fc1 (scalar)
-    fc2_weight_blockscale_fp8: torch.Tensor,  # Global scale for w2 (scalar)
-    fc1_act_global_scale: torch.Tensor,  # Global scale for FC1 activations
-    fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
-    fc1_alpha: torch.Tensor,  # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
-    fc2_alpha: torch.Tensor,  # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
+    fc1_expert_weights_fp4: torch.Tensor,
+    fc2_expert_weights_fp4: torch.Tensor,
+    fc1_weight_blockscale_fp8: torch.Tensor,
+    fc2_weight_blockscale_fp8: torch.Tensor,
+    fc1_act_global_scale: torch.Tensor,
+    fc2_act_global_scale: torch.Tensor,
+    fc1_alpha: torch.Tensor,
+    fc2_alpha: torch.Tensor,
     is_gated_mlp: bool = True,
     act_fn: int = int(ActivationType.Silu),
 ) -> torch.Tensor:
@@ -234,28 +241,100 @@ def trtllm_quant_nvfp4_moe_fused(
     For mlp:
         y = act(x @ w1.T) @ w2.T  # act := ReLU^2
 
+    Notes:
+        - FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
+        - FC2 implements: fc2_output = fc1_output @ w2.T
+        - FC1 weights are concatenated w3 and w1 if gated_mlp, otherwise w1
+        - FP4 elements pairs are packed as a single uint8 element
 
-    FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
-    FC2 implements: fc2_output = fc1_output @ w2.T
-
+    Parameters:
+        x: BF16/FP16 input tensor of shape (B, H) or (B, S, H)
+        selected_experts: Expert indices (B*S, TOP_K)
+        routing_weights: Routing weights (B*S, TOP_K)
+        fc1_expert_weights_fp4: FP4 FC1 weights [E, 2*I, H/2] or [E, I, H/2]; packed uint8
+        fc2_expert_weights_fp4: FP4 FC2 weights [E, H, I/2]; packed uint8
+        fc1_weight_blockscale_fp8: Block scales for FC1 weights (w1 or cat(w3, w1))
+        fc2_weight_blockscale_fp8: Block scales for FC2 weights (w2)
+        fc1_act_global_scale: Global scale for FC1 activations (scalar)
+        fc2_act_global_scale: Global scale for FC2 activations (scalar)
+        fc1_alpha: FC1 dequant scales = 1.0 / (fc1_act_global_scale * fc1_weight_global_scale)
+        fc2_alpha: FC2 dequant scales = 1.0 / (fc2_act_global_scale * fc2_weight_global_scale)
+        mlp_style: "gated_mlp" or "mlp"
+        act_fn: "silu" for gated_mlp, "relu2" for mlp
     """
-    NVFP4_BLOCK_SIZE = 16
+    NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
+    FP4_PER_UINT8 = 2
 
-    activation_type = ActivationType.Swiglu
-    if is_gated_mlp:
-        if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
-            activation_type = ActivationType.Swiglu
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
-            )
+    _, fc1_inter_size, _ = fc1_expert_weights_fp4.shape
+    n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape
+
+    # Convert the inter_size from number of uint8 elements to number of FP4 elements.
+    inter_size *= FP4_PER_UINT8
+
+    # Validate shapes and padding requirements as defined by the cutlass kernel.
+    assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D"
+    assert fc2_weight_blockscale_fp8.ndim == 3, "fc2_weight_blockscale_fp8 must be 3D"
+    assert fc1_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
+    assert fc2_weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
+    assert fc1_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
+    assert fc2_weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
+
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
+    act_fn = ActivationType.Swiglu if act_fn == ActivationType.Silu else act_fn
+
+    if x.dtype in (torch.float16, torch.bfloat16):
+        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
+            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
+        )
+        output_dtype = x.dtype
     else:
-        if act_fn == ActivationType.Relu2:
-            activation_type = ActivationType.Relu2
-        else:
-            raise ValueError(
-                f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
-            )
+        x_q_fp4 = x
+        input_blockscale = None
+        output_dtype = x.dtype
+
+    # Pad inter_size to be divisible by 128
+    inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
+    fc1_inter_size_padded = (
+        math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
+    )
+    hidden_size_padded = (
+        math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE
+    )
+
+    inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or (
+        not is_gated_mlp and inter_size_padded != inter_size
+    )
+    hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0
+    if inter_size_needs_padding or hidden_size_needs_padding:
+        assert False, "See https://github.com/NVIDIA/TensorRT-LLM/issues/10331"
+        # fc1_expert_weights_fp4: [E, I, H] or [E, 2*I, H]
+        fc1_padded = fc1_expert_weights_fp4.new_zeros(
+            fc1_expert_weights_fp4.size(0),
+            fc1_inter_size_padded,
+            hidden_size_padded // FP4_PER_UINT8,
+        )
+        fc1_padded[:, :fc1_inter_size, :] = fc1_expert_weights_fp4
+        fc1_expert_weights_fp4 = fc1_padded
+
+        # fc2_expert_weights_fp4: [E, H, I]
+        fc2_padded = fc2_expert_weights_fp4.new_zeros(
+            n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
+        )
+
+        assert inter_size % NVFP4_BLOCK_SIZE == 0, (
+            f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
+        )
+
+        fc2_padded[:, :, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
+        fc2_expert_weights_fp4 = fc2_padded
+
+        fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
+            n_experts, hidden_size_padded, inter_size_padded // NVFP4_BLOCK_SIZE
+        )
+        fc2_blockscale_fp8_padded[:, :, : inter_size // NVFP4_BLOCK_SIZE] = (
+            fc2_weight_blockscale_fp8
+        )
+        fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
 
     # quant_scales is described by this code:
     # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
@@ -270,26 +349,19 @@ def trtllm_quant_nvfp4_moe_fused(
         fc2_alpha,  # torch.float32; [E]
     ]
 
-    if x.dtype in (torch.float16, torch.bfloat16):
-        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
-            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
-        )
-        output_dtype = x.dtype
-    else:
-        x_q_fp4 = x
-
     trtllm_output = torch.ops.trtllm.fused_moe(
-        x_q_fp4,
-        selected_experts.to(torch.int),
-        routing_weights,
-        fc1_expert_weights=fc1_expert_weights_fp4,
+        x_q_fp4.view(torch.long),
+        selected_experts.to(torch.int32),
+        routing_weights.to(torch.float32),
+        # Groups of 16 FP4 weight elements are packed as a single int64 element (see isNvfp4Quant in moeOp.cpp)
+        fc1_expert_weights=fc1_expert_weights_fp4.view(torch.long),
         fc1_expert_biases=None,
         fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long),
         fc2_expert_biases=None,
        output_dtype=output_dtype,
         quant_scales=quant_scales,
         input_sf=input_blockscale,
-        activation_type=activation_type,
+        activation_type=act_fn,
     )[0].view(x.shape)
 
     return trtllm_output
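
The padding branch above rounds the intermediate and hidden sizes up to the kernel's alignment multiples with math.ceil. A standalone sketch of the same round-up arithmetic, with made-up example sizes that are not taken from the commit:

import math

TRTLLM_NVFP4_ROW_SIZE = 128  # the intermediate (row) dimension must be a multiple of this


def round_up(value: int, multiple: int) -> int:
    # Same computation as in the op: math.ceil(value / multiple) * multiple.
    return math.ceil(value / multiple) * multiple


# An intermediate size of 1344 is not a multiple of 128, so the expert weights
# would have to be zero-padded up to 1408 rows before calling the fused kernel.
assert round_up(1344, TRTLLM_NVFP4_ROW_SIZE) == 1408
# An already-aligned size is left unchanged.
assert round_up(1408, TRTLLM_NVFP4_ROW_SIZE) == 1408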

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 TRTLLM_FP4_OP_AVAILABLE = True
 
 TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16
+TRTLLM_NVFP4_ROW_SIZE = 128
+TRTLLM_NVFP4_COLUMN_SIZE = 4
 
 
 @torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())
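
The two new constants encode the block-scale alignment that the fused op asserts on (rows a multiple of 128, columns a multiple of 4). A small illustrative check built only on those asserts; the function name is hypothetical and not part of the commit:

import torch

TRTLLM_NVFP4_ROW_SIZE = 128
TRTLLM_NVFP4_COLUMN_SIZE = 4


def blockscale_is_kernel_aligned(weight_blockscale_fp8: torch.Tensor) -> bool:
    # Mirrors the asserts in trtllm_quant_nvfp4_moe_fused: the per-expert block-scale
    # tensor must be 3D ([E, rows, cols]) with rows % 128 == 0 and cols % 4 == 0, where
    # cols is the weight's last dimension divided by the 16-element NVFP4 scaling vector.
    return (
        weight_blockscale_fp8.ndim == 3
        and weight_blockscale_fp8.size(1) % TRTLLM_NVFP4_ROW_SIZE == 0
        and weight_blockscale_fp8.size(2) % TRTLLM_NVFP4_COLUMN_SIZE == 0
    )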
