Commit f1fd5c6
Trtllm-gen Fp4 MoE Autotunner (#1475)
<!-- .github/pull_request_template.md -->

## 📌 Description

- Update the `AutoTuner`, `OptimizationProfile`, and `DynamicTensorSpec`.
  - `DynamicTensorSpec` can take multiple input tensors.
  - The `tensor_initializers` in `DynamicTensorSpec` defines the initialization method for dynamic tensors. Previously they were all zero-initialized, which caused illegal memory access (IMA) in trtllm-gen's routing kernels.
- Add `DtypeTrtllmGen` in `flashinfer/fused_moe/core.py`.
- Add the autotuner to the trtllm-gen fp4 MoE.
- Relax the check on `hidden_states_scale` in the trtllm-gen fp4 MoE: it no longer needs to be 1D.
  - **When autotuning, it must be 2D.**

### TODOs

- Unify the launcher for both trtllm-gen fp8 and fp4 MoE.
- After unifying the launchers, add the autotuner to the fp8 MoE.
- If routing is DeepSeek V3, there will be an illegal memory access. A workaround (WAR) is to limit the search space to [8 ... 1024].

### Performance

B200, clock speed locked at 1500 MHz, 1000 warmups, 1000 iterations, `mxfp4 x mxfp8`:

| num_tokens | w/o tuner | with tuner | diff   |
|------------|-----------|------------|--------|
| 1          | 0.042     | 0.042      | 0.00%  |
| 2          | 0.057     | 0.057      | 0.00%  |
| 4          | 0.081     | 0.081      | 0.00%  |
| 8          | 0.114     | 0.105      | 7.89%  |
| 16         | 0.201     | 0.183      | 8.96%  |
| 32         | 0.274     | 0.246      | 10.22% |
| 64         | 0.348     | 0.308      | 11.49% |
| 128        | 0.41      | 0.365      | 10.98% |
| 256        | 0.548     | 0.429      | 21.72% |
| 512        | 0.576     | 0.453      | 21.35% |
| 1024       | 0.651     | 0.578      | 11.21% |

For `nvfp4 x nvfp4` and `mxfp4 x bf16`, there is no significant perf gain.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent dd9a333 commit f1fd5c6

File tree: 8 files changed, +763 −76 lines
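The new benchmark script below drives the tuner end to end. As a quick orientation, the core flow is sketched here; this is a minimal sketch, not a standalone script: `num_tokens`, `hidden_size`, `warmup_iters`, and `run_fp4_moe` are placeholders, and the full `trtllm_fp4_block_scale_moe` argument list is taken from the benchmark that follows.

```python
# Minimal sketch of the tuning flow exercised by the benchmark below.
import torch
from flashinfer import mxfp8_quantize
from flashinfer.autotuner import autotune

hidden_states = torch.randn(
    num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16
)
hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False)
# When autotuning, hidden_states_scale must be 2D: one row of scale factors per token.
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
    num_tokens, -1
)

with autotune(True):  # calls made inside the context are tuned
    for _ in range(warmup_iters):
        run_fp4_moe()  # wraps trtllm_fp4_block_scale_moe(...) as in the benchmark
out = run_fp4_moe()  # later calls reuse the best configuration found
```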
Lines changed: 227 additions & 0 deletions (new file)

```python
import argparse
from typing import Optional, Literal
import torch
import numpy as np
from flashinfer import (
    fp4_quantize,
    mxfp8_quantize,
    next_positive_power_of_2,
)
from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
from flashinfer.autotuner import autotune
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import device_support_pdl


def get_tile_tokens_dim(num_tokens, num_experts, top_k):
    # Factor to account for the imbalance of the experts.
    # The factor equals
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
    # - 1.0 means perfect expert distribution.
    # - > 1.0 means some experts have more
    #   tokens than the perfect distribution.
    # - < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Calculate the number of tokens per expert
    # assuming perfect distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile
    # as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim


def bench_trtllm_gen_fused_moe_autotuner(
    tune_max_num_tokens: Optional[int],
    quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    warmups: int,
    iterations: int,
):
    device = torch.device("cuda:0")
    enable_pdl = device_support_pdl(device)
    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
        torch.bfloat16
    )
    hidden_states = torch.randn(num_tokens, hidden_size, device=device).to(
        torch.bfloat16
    )
    if quant_mode == "NvFP4xNvFP4":
        hidden_states, hidden_states_scale = fp4_quantize(
            hidden_states,
            torch.tensor([448.0 * 6.0], device=device),
            sf_vec_size=16,
            sf_use_ue8m0=False,
        )
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
            num_tokens, -1
        )
        hidden_states_global_scale = 1.0 / 448.0 / 6.0
    elif quant_mode == "MxFP4xMxFP8":
        hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False)
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
            num_tokens, -1
        )
        hidden_states_global_scale = 1.0
    else:  # MxFP4xBf16
        hidden_states_scale = None
        hidden_states_global_scale = 1.0

    w13 = torch.randn(
        num_experts, intermediate_size * 2, hidden_size, device=device
    ).to(torch.bfloat16)
    w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to(
        torch.bfloat16
    )
    if quant_mode == "NvFP4xNvFP4":
        w13, w13_scale = fp4_quantize(
            w13,
            torch.tensor([448.0 * 6.0], device=device),
            sf_vec_size=16,
            sf_use_ue8m0=False,
        )
        w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, intermediate_size * 2, -1
        )
        w2, w2_scale = fp4_quantize(
            w2,
            torch.tensor([448.0 * 6.0], device=device),
            sf_vec_size=16,
            sf_use_ue8m0=False,
        )
        w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, -1
        )
        w13_global_scale = 1.0 / 448.0 / 6.0
        w2_global_scale = 1.0 / 448.0 / 6.0
    else:
        w13, w13_scale = fp4_quantize(
            w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
        )
        w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, intermediate_size * 2, -1
        )
        w2, w2_scale = fp4_quantize(
            w2, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True
        )
        w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, -1
        )
        w13_global_scale = 1.0
        w2_global_scale = 1.0
    bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
    bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10

    tile_tokens_dim = get_tile_tokens_dim(num_tokens, num_experts, top_k)
    output1_scale_scalar = torch.tensor(
        [hidden_states_global_scale * w13_global_scale] * num_experts, device=device
    )
    output1_scale_gate_scalar = torch.tensor(
        [hidden_states_global_scale * w13_global_scale] * num_experts, device=device
    )
    output2_scale_scalar = torch.tensor(
        [hidden_states_global_scale * w2_global_scale] * num_experts, device=device
    )
    fn = lambda: trtllm_fp4_block_scale_moe(
        routing_logits,
        None,  # routing_bias
        hidden_states,
        hidden_states_scale,
        w13,
        w13_scale,
        bias13,
        None,  # gemm1_alpha
        None,  # gemm1_beta
        None,  # gemm1_clamp_limit
        w2,
        w2_scale,
        bias2,
        output1_scale_scalar,
        output1_scale_gate_scalar,
        output2_scale_scalar,
        num_experts,
        top_k,
        None,  # n_group
        None,  # topk_group
        intermediate_size,
        0,  # local_expert_offset
        num_experts,
        None,  # routed_scaling_factor
        tile_tokens_dim,
        1,
        True,
        enable_pdl,
        None,
        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
    )

    def bench(do_autotune):
        # warmup
        with autotune(do_autotune):
            for _ in range(warmups):
                fn()
        ms_list = bench_gpu_time(
            fn,
            repeat_iters=iterations,
        )
        median_ms = np.median(ms_list)
        return median_ms

    ms = bench(do_autotune=False)
    ms_tuned = bench(do_autotune=True)
    print(
        f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}"
    )
    print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quant-mode",
        type=str,
        default="MxFP4xMxFP8",
        choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"],
        help="Quantization mode",
    )
    parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens")
    parser.add_argument(
        "--tune-max-num-tokens",
        type=int,
        default=None,
        help="Maximum number of tokens for tuning",
    )
    parser.add_argument(
        "--num-experts", type=int, default=128, help="Number of experts"
    )
    parser.add_argument("--hidden-size", type=int, default=3072, help="Hidden size")
    parser.add_argument(
        "--intermediate-size", type=int, default=3072, help="Intermediate size"
    )
    parser.add_argument("--top-k", type=int, default=4, help="Top-k experts per token")
    parser.add_argument(
        "--warmups", type=int, default=100, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of benchmark iterations"
    )
    args = parser.parse_args()
    bench_trtllm_gen_fused_moe_autotuner(
        args.tune_max_num_tokens,
        args.quant_mode,
        args.num_tokens,
        args.num_experts,
        args.hidden_size,
        args.intermediate_size,
        args.top_k,
        args.warmups,
        args.iterations,
    )
```
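For the benchmark's default arguments (num_tokens=512, num_experts=128, top_k=4), the tile sizing in `get_tile_tokens_dim` works out as follows; the power-of-two step below is a plain-Python stand-in for flashinfer's `next_positive_power_of_2`.

```python
# Worked example of get_tile_tokens_dim with the default CLI arguments.
num_tokens, num_experts, top_k = 512, 128, 4
per_expert = (num_tokens * top_k) // num_experts   # 2048 // 128 = 16
per_expert = int(per_expert * 1.3)                 # imbalance factor -> 20
tile = 1 << (per_expert - 1).bit_length()          # next power of 2 -> 32
tile = min(max(tile, 8), 64)                       # clamp to kernel range -> 32
print(tile)  # 32, the tile_tokens_dim passed to trtllm_fp4_block_scale_moe
```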

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 40 additions & 9 deletions

```diff
@@ -881,11 +881,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe_launcher(
     TORCH_CHECK(hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
                 "hidden_states_scale must be fp8.");
 
-    TORCH_CHECK(hidden_states_scale.value().dim() == 1, "hidden_states_scale must be 1D.");
-    TORCH_CHECK(hidden_states_scale.value().sizes()[0] ==
-                    tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens,
-                                                            args.hidden_size / sf_vec_size),
-                "hidden_states_scale has incorrect size");
+    TORCH_CHECK(
+        hidden_states_scale.value().numel() == tensorrt_llm::computeLinearLayoutSFSize(
+                                                   args.num_tokens, args.hidden_size / sf_vec_size),
+        "hidden_states_scale has incorrect size");
   }
 
   TORCH_CHECK(gemm1_weights.scalar_type() == torch_ext::FLOAT4_E2M1X2,
@@ -1059,7 +1058,8 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
     std::optional<int64_t> n_group, std::optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts,
     std::optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
-    int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output) {
+    int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output,
+    int64_t config_index) {
   using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
 
   int const num_tokens = hidden_states.sizes()[0];
@@ -1112,8 +1112,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
       mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
       tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
 
-  auto const moeConfigIndex = mRunner->getDefaultValidConfigIndex(
-      top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+  if (config_index == -1) {
+    config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                       local_num_experts, num_tokens);
+  }
 
   return trtllm_fp4_block_scale_moe_launcher(
       routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
@@ -1122,7 +1124,34 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
       output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
       intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor,
       tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights,
-      moeConfigIndex, enable_pdl, output);
+      config_index, enable_pdl, output);
+}
+
+int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_,
+                                       int64_t const dtype_weights_, bool const useDeepSeekFp8,
+                                       int64_t const top_k, int64_t const hidden_size,
+                                       int64_t const intermediate_size,
+                                       int64_t const num_local_experts, int64_t const num_tokens) {
+  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
+  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
+  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
+      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
+      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                               num_local_experts, num_tokens);
+}
+
+std::vector<int64_t> trtllm_get_valid_moe_configs(
+    int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
+    bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
+    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
+  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
+  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
+  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
+      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
+      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+  return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
+                                          num_tokens);
 }
 
 namespace trtllm_cubin_loader {
@@ -1133,6 +1162,8 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe);
   m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe);
   m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe);
+  m.def("trtllm_get_default_moe_configs", trtllm_get_default_moe_configs);
+  m.def("trtllm_get_valid_moe_configs", trtllm_get_valid_moe_configs);
 }
 
 }  // namespace flashinfer
```
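The two new bindings return the default config index and the list of valid config indices for a given problem shape, which the Python autotuner can use as its search space. A minimal sketch of that query is below; the argument order follows the C++ signatures above, but `moe_module`, `dtype_act`, and `dtype_weights` are placeholders (the Python-side accessor and the integer values of `DtypeTrtllmGen` are not shown in this diff and are assumed here).

```python
# Hypothetical sketch: enumerate tunable configs for one problem shape.
# `moe_module` stands for the loaded extension that registered the ops above;
# `dtype_act` / `dtype_weights` stand for DtypeTrtllmGen values from
# flashinfer/fused_moe/core.py -- all three are placeholders.
valid = moe_module.trtllm_get_valid_moe_configs(
    tile_tokens_dim,   # e.g. 32, from get_tile_tokens_dim(...)
    dtype_act,         # activation dtype as an integer DtypeTrtllmGen value
    dtype_weights,     # weight dtype as an integer DtypeTrtllmGen value
    False,             # useDeepSeekFp8
    top_k, hidden_size, intermediate_size,
    num_local_experts, num_tokens,
)
default = moe_module.trtllm_get_default_moe_configs(
    tile_tokens_dim, dtype_act, dtype_weights, False,
    top_k, hidden_size, intermediate_size, num_local_experts, num_tokens,
)
# The autotuner can time the kernel once per index in `valid` and pass the best
# one as `config_index` to trtllm_fp4_block_scale_moe; config_index == -1
# falls back to the default config, as shown in the launcher diff above.
```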
