Commit 9c609f0

Update cutlass fp4 moe kernels (#1294)
1 parent f0b235c commit 9c609f0

File tree

66 files changed: +5712 / -3254 lines


benchmarks/bench_cutlass_fused_moe.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
from torch.nn import functional as F
from triton.testing import do_bench

import flashinfer
import flashinfer.fused_moe as fused_moe
from flashinfer import fp4_quantize

BATCH_SIZES = [
    1,
    2,
    4,
    8,
    16,
    24,
    32,
    48,
    64,
    96,
    128,
    256,
    512,
    1024,
    1536,
    2048,
    3072,
    4096,
]

configs = []
hidden_size = 7168
num_experts = [32, 256]
top_k = [8]
intermediate_size = [256, 2048]
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FP8_DTYPE = torch.float8_e4m3fn

test_configs = [
    {
        "hidden_size": 7168,
        "num_experts": 256,
        "top_k": 8,
        "intermediate_size": 256,
    },
    {
        "hidden_size": 7168,
        "num_experts": 32,
        "top_k": 8,
        "intermediate_size": 2048,
    },
]


def compute_routing(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute routing weights and selected experts from router logits.

    Args:
        router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts]
        top_k (int): Number of experts to route to per token

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - routing_weights: Expert weights of shape [batch_size, top_k]
            - selected_experts: Expert indices of shape [batch_size, top_k]
    """
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.float()
    return routing_weights, selected_experts


def bench_cutlass_fused_moe(
    batch_size,
    hidden_size,
    num_experts,
    top_k,
    intermediate_size,
):
    torch.manual_seed(42)
    quant_blocksize = 16
    round_up = lambda x, y: (x + y - 1) // y * y
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    otype = torch.bfloat16
    wtype = torch.float8_e4m3fn
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
    w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()

    sf_w1_2n = round_up(2 * n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_blockscale_cutlass = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )

    w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)

    for expert in range(e):
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])

        w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
            w1_cutlass[expert], w1_gs[expert]
        )

        w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])

    x = torch.randn(m, k, dtype=otype).cuda()
    a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to(
        torch.float32
    ).cuda()
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    router_logits = torch.randn(m, e, dtype=otype).cuda()
    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    flash_output = torch.zeros_like(x)

    quant_scales = [
        a1_gs,
        w1_blockscale.view(torch.int32),
        1.0 / (a1_gs * w1_gs),
        a2_gs,
        w2_blockscale.view(torch.int32),
        1.0 / (a2_gs * w2_gs),
    ]
    hidden_states = x
    hidden_states, input_sf = fp4_quantize(x, a1_gs)
    repeats = 3
    from flashinfer.autotuner import AutoTuner, autotune

    AutoTuner.get().clear_cache()
    with torch.inference_mode(), autotune():
        for _ in range(2):
            _ = fused_moe.cutlass_fused_moe(
                hidden_states,
                selected_experts.to(torch.int),
                routing_weights,
                w1_q.contiguous().view(torch.long),
                w2_q.contiguous().view(torch.long),
                otype,
                quant_scales=quant_scales,
                input_sf=input_sf,
                output=flash_output,
            )
    ms = do_bench(
        lambda: fused_moe.cutlass_fused_moe(
            hidden_states,
            selected_experts.to(torch.int),
            routing_weights,
            w1_q.contiguous().view(torch.long),
            w2_q.contiguous().view(torch.long),
            otype,
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
        )
    )
    print(
        f"batch_size={batch_size}, num_experts={num_experts}, top_k={top_k}, intermediate_size={intermediate_size}"
    )
    print(f"execution time: {ms}ms")


if __name__ == "__main__":
    for config in test_configs:
        hidden_size = config["hidden_size"]
        num_experts = config["num_experts"]
        top_k = config["top_k"]
        intermediate_size = config["intermediate_size"]
        for batch_size in BATCH_SIZES:
            bench_cutlass_fused_moe(
                batch_size,
                hidden_size,
                num_experts,
                top_k,
                intermediate_size,
            )
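The script sweeps every batch size in BATCH_SIZES for both entries in test_configs. For a quick spot check of a single point, the benchmark function can also be driven directly; the sketch below is not part of the commit and assumes a CUDA device with FP4/FP8 support, a FlashInfer build that includes the CUTLASS MoE backend, and that the script is importable (for example, when run from the benchmarks/ directory).

# Minimal sketch: time one (batch_size, config) point instead of the full sweep.
# Assumes bench_cutlass_fused_moe.py is on the Python path and a capable CUDA GPU is present.
from bench_cutlass_fused_moe import bench_cutlass_fused_moe

if __name__ == "__main__":
    bench_cutlass_fused_moe(
        batch_size=128,          # any value from BATCH_SIZES
        hidden_size=7168,
        num_experts=256,
        top_k=8,
        intermediate_size=256,   # matches the first entry in test_configs
    )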

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu

Lines changed: 6 additions & 3 deletions
@@ -17,7 +17,7 @@
 #include "cutlass_fused_moe_kernels.cuh"
 #include "moe_kernels.h"
 
-namespace tensorrt_llm::kernels {
+namespace tensorrt_llm::kernels::cutlass_kernels {
 // ==================== Variable batched GEMM specializations ==================================
 template class CutlassMoeFCRunner<float, float>;
 
@@ -43,10 +43,13 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat1
 #ifdef ENABLE_FP4
 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>;
 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>;
+template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half>;
+template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half, half>;
 #ifdef ENABLE_BF16
 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>;
 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>;
+template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16>;
+template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>;
 #endif
 #endif
-
-};  // namespace tensorrt_llm::kernels
+};  // namespace tensorrt_llm::kernels::cutlass_kernels
