Commit 263495f

yewentao256 authored and ProExpertProg committed
[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement (vllm-project#31830)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
1 parent 527f8db commit 263495f
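In short: the grouped-GEMM problem sizes were previously recomputed by scanning topk_ids with one thread block per expert and atomic counters, even though moe_permute had already materialized expert_first_token_offset. The new op derives each expert's M directly from adjacent offsets instead. A rough PyTorch sketch of the two counting strategies (the helper names here are ours, for illustration only):

import torch

# Old path (conceptual): histogram topk_ids per expert; the CUDA kernel does
# this with one block per expert and atomic counters.
def count_tokens_per_expert(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    valid = topk_ids[topk_ids >= 0]  # -1 marks entries routed to non-local experts
    return torch.bincount(valid.flatten(), minlength=num_experts)

# New path: moe_permute already produced expert_first_token_offset, so each
# expert's M is a difference of adjacent offsets -- O(num_experts) work total.
def count_from_offsets(expert_first_token_offset: torch.Tensor) -> torch.Tensor:
    return torch.diff(expert_first_token_offset)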

File tree: 6 files changed, +172 -63 lines changed

csrc/ops.h

Lines changed: 5 additions & 0 deletions

@@ -265,6 +265,11 @@ void get_cutlass_moe_mm_problem_sizes(
     const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
     std::optional<bool> force_swap_ab = std::nullopt);
 
+void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+    const torch::Tensor& expert_first_token_offset,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    const int64_t n, const int64_t k, const bool swap_ab);
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,

csrc/quantization/w8a8/cutlass/moe/moe_data.cu

Lines changed: 96 additions & 12 deletions

@@ -3,6 +3,8 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
 
+#include "dispatch_utils.h"
+
 #include <iostream>
 
 constexpr uint64_t THREADS_PER_EXPERT = 512;
@@ -114,22 +116,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
                                          const bool swap_ab) {
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
 
-  const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
-  int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
-  int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
-  int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
+  auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
+  auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
+  auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
+  auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
 
-  if (swap_ab) {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+  VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
+    compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
         topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
         static_cast<int>(topk_ids.numel()), static_cast<int>(n),
         static_cast<int>(k));
-  } else {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
-        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
-        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
-        static_cast<int>(k));
-  }
+  });
 }
 }  // namespace
 
@@ -153,6 +150,93 @@ void get_cutlass_moe_mm_problem_sizes_caller(
       may_swap_ab);
 }
 
+template <bool SWAP_AB>
+__global__ void compute_problem_sizes_from_expert_offsets(
+    const int64_t* __restrict__ expert_first_token_offset,
+    int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2,
+    const int num_experts, const int n, const int k) {
+  int const expert_id = blockIdx.x * blockDim.x + threadIdx.x;
+  if (expert_id >= num_experts) {
+    return;
+  }
+
+  int64_t const m64 = expert_first_token_offset[expert_id + 1] -
+                      expert_first_token_offset[expert_id];
+  int32_t const m = static_cast<int32_t>(m64);
+
+  int32_t* ps1 = problem_sizes1 + expert_id * 3;
+  int32_t* ps2 = problem_sizes2 + expert_id * 3;
+
+  if constexpr (!SWAP_AB) {
+    // [M, 2*N, K]
+    ps1[0] = m;
+    ps1[1] = 2 * n;
+    ps1[2] = k;
+    // [M, K, N]
+    ps2[0] = m;
+    ps2[1] = k;
+    ps2[2] = n;
+  } else {
+    // swap logical M/N in the problem shape
+    // [2*N, M, K]
+    ps1[0] = 2 * n;
+    ps1[1] = m;
+    ps1[2] = k;
+    // [K, M, N]
+    ps2[0] = k;
+    ps2[1] = m;
+    ps2[2] = n;
+  }
+}
+
+void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
+    const torch::Tensor& expert_first_token_offset,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    const int64_t n, const int64_t k, const bool swap_ab) {
+  TORCH_CHECK(expert_first_token_offset.is_cuda(),
+              "expert_first_token_offset must be a CUDA tensor");
+  TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
+              "expert_first_token_offset must be int64");
+
+  TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
+              "problem_sizes must be CUDA tensors");
+  TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
+                  problem_sizes2.dtype() == torch::kInt32,
+              "problem_sizes must be int32");
+  TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
+              "problem_sizes must be contiguous");
+  TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
+              "problem_sizes must be 2D tensors");
+  TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
+              "problem_sizes second dim must be 3");
+  TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
+              "problem_sizes1 and problem_sizes2 must have same shape");
+
+  int64_t const num_experts64 = problem_sizes1.size(0);
+  TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
+              "expert_first_token_offset must have num_experts + 1 elements");
+  TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
+  TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
+
+  int const num_experts = static_cast<int>(num_experts64);
+  auto stream = at::cuda::getCurrentCUDAStream(
+      expert_first_token_offset.device().index());
+
+  int const threads = (num_experts < 256) ? num_experts : 256;
+  int const blocks = (num_experts + threads - 1) / threads;
+
+  auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
+  auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
+  auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
+
+  VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
+    compute_problem_sizes_from_expert_offsets<SwapAB>
+        <<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
+                                         num_experts, static_cast<int>(n),
+                                         static_cast<int>(k));
+  });
+}
+
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
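For intuition, the new kernel above is equivalent to the following PyTorch reference (a validation sketch only; ref_problem_sizes_from_offsets is our name, not part of the commit):

import torch

def ref_problem_sizes_from_offsets(
    expert_first_token_offset: torch.Tensor,  # int64, shape [num_experts + 1]
    n: int,
    k: int,
    swap_ab: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-expert M is the difference of adjacent offsets; no atomic
    # histogram over topk_ids is needed.
    m = torch.diff(expert_first_token_offset).to(torch.int32)
    n_col = torch.full_like(m, n)
    n2_col = torch.full_like(m, 2 * n)
    k_col = torch.full_like(m, k)
    if not swap_ab:
        ps1 = torch.stack([m, n2_col, k_col], dim=1)  # [M, 2*N, K]
        ps2 = torch.stack([m, k_col, n_col], dim=1)   # [M, K, N]
    else:
        ps1 = torch.stack([n2_col, m, k_col], dim=1)  # [2*N, M, K]
        ps2 = torch.stack([k_col, m, n_col], dim=1)   # [K, M, N]
    return ps1, ps2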

csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu

Lines changed: 24 additions & 0 deletions

@@ -83,6 +83,11 @@ void get_cutlass_moe_mm_problem_sizes_caller(
     const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
     std::optional<bool> force_swap_ab = std::nullopt);
 
+void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
+    const torch::Tensor& expert_first_token_offset,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    const int64_t n, const int64_t k, const bool swap_ab);
+
 void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                          torch::Tensor& problem_sizes1,
                                          torch::Tensor& problem_sizes2,
@@ -322,6 +327,25 @@ void get_cutlass_moe_mm_problem_sizes(
       version_num, ". Required capability: 90, 100, or 120");
 }
 
+void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+    const torch::Tensor& expert_first_token_offset,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    const int64_t n, const int64_t k, const bool swap_ab) {
+  int32_t version_num = get_sm_version_num();
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
+    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
+    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
+  get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
+      expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
+      "no cutlass_scaled_mm kernel for CUDA device capability: ",
+      version_num, ". Required capability: 90, 100, or 120");
+}
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,

csrc/torch_bindings.cpp

Lines changed: 11 additions & 0 deletions

@@ -487,6 +487,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
            &get_cutlass_moe_mm_problem_sizes);
 
+  // compute per-expert problem sizes from expert_first_token_offset
+  // produced by vLLM's moe_permute kernel
+  ops.def(
+      "get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
+      "   Tensor expert_first_token_offset, "
+      "   Tensor! problem_sizes1, "
+      "   Tensor! problem_sizes2, "
+      "   int n, int k, bool swap_ab) -> ()");
+  ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
+           &get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
+
   // A function that computes data required to run fused MoE with w8a8 grouped
   // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
   // as an input, and computes expert_offsets (token start indices of each

vllm/_custom_ops.py

Lines changed: 19 additions & 0 deletions

@@ -1075,6 +1075,25 @@ def get_cutlass_moe_mm_problem_sizes(
     )
 
 
+def get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+    expert_first_token_offset: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    n: int,
+    k: int,
+    swap_ab: bool,
+):
+    """Compute per-expert (M, N, K) problem sizes from expert_first_token_offset"""
+    return torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+        expert_first_token_offset,
+        problem_sizes1,
+        problem_sizes2,
+        n,
+        k,
+        swap_ab,
+    )
+
+
 def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
     """
     Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
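A minimal call sketch for the new wrapper, assuming a vLLM build with the CUTLASS MoE kernels compiled (SM90/100/120) and a CUDA device; the toy offsets below are purely illustrative:

import torch
from vllm import _custom_ops as ops

num_experts, n, k = 8, 2048, 1024
# Offsets for 8 experts over 16 permuted tokens (int64, length num_experts + 1).
offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16],
                       dtype=torch.int64, device="cuda")
ps1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
ps2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")

# Writes in place (Tensor! in the op schema): ps1 rows become [M, 2*N, K],
# ps2 rows [M, K, N] when swap_ab is False.
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    offsets, ps1, ps2, n, k, swap_ab=False
)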

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 17 additions & 51 deletions

@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8(
     assert global_num_experts != -1
     assert a1q_scale is not None
 
-    if expert_map is not None:
-        "Translate info from expert_map to topk_ids"
-        local_topk_ids = torch.where(
-            expert_map[topk_ids] != -1, expert_map[topk_ids], -1
-        )
-    else:
-        local_topk_ids = topk_ids
-
-    topk = local_topk_ids.size(1)
+    topk = topk_ids.size(1)
     local_E = w1.size(0)
 
     if use_batched_format:
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8(
         # during offset calculations
         expert_offsets = expert_offsets.to(torch.int64)
     else:
-        problem_sizes1 = torch.empty(
-            (global_num_experts, 3), dtype=torch.int32, device=device
-        )
-        problem_sizes2 = torch.empty(
-            (global_num_experts, 3), dtype=torch.int32, device=device
-        )
+        problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
+        problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
 
         num_expert = global_num_experts if expert_map is None else expert_map.size(0)
         # permuted a1q reuses workspace2
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8(
             expert_map,
             permuted_hidden_states=a1q_perm,
         )
-        expert_offsets = expert_first_token_offset[:-1]
-
-        ops.get_cutlass_moe_mm_problem_sizes(
-            local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
+        # swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
+        swap_ab = a1q.size(0) <= 64
+        ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+            expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab
         )
+        expert_offsets = expert_first_token_offset[:-1]
 
     if not per_act_token and (expert_map is not None or use_batched_format):
         # this is necessary to avoid imprecise scale calculation caused by
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8(
         permuted_hidden_states=mm2_out,
         topk_weights=topk_weights,
         inv_permuted_idx=inv_perm,
-        expert_first_token_offset=(
-            expert_first_token_offset if expert_map is not None else None
-        ),
+        expert_first_token_offset=expert_first_token_offset,
     )
 
 
@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8(
         f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
     )
 
-    # Translate info from expert_map to topk_ids
-    if expert_map is not None:
-        local_topk_ids = torch.where(
-            expert_map[topk_ids] != -1, expert_map[topk_ids], -1
-        )
-    else:
-        local_topk_ids = topk_ids
-
-    topk = local_topk_ids.size(1)
+    topk = topk_ids.size(1)
     a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
     mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
     act_out = _resize_cache(workspace2, (M * topk, N))
@@ -790,12 +769,8 @@ def run_cutlass_moe_w4a8_fp8(
     )
     mm2_out = _resize_cache(workspace2, (M * topk, K))
 
-    problem_sizes1 = torch.empty(
-        (global_num_experts, 3), dtype=torch.int32, device=device
-    )
-    problem_sizes2 = torch.empty(
-        (global_num_experts, 3), dtype=torch.int32, device=device
-    )
+    problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
+    problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
 
     num_expert = global_num_experts if expert_map is None else expert_map.size(0)
     # permuted a1q reuses workspace2
@@ -808,18 +783,11 @@ def run_cutlass_moe_w4a8_fp8(
         expert_map,
         permuted_hidden_states=a1q_perm,
     )
-    expert_offsets = expert_first_token_offset[:-1]
-
-    # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
-    ops.get_cutlass_moe_mm_problem_sizes(
-        local_topk_ids,
-        problem_sizes1,
-        problem_sizes2,
-        global_num_experts,
-        N,
-        K,
-        force_swap_ab=True,
+    # for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
+    ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+        expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True
     )
+    expert_offsets = expert_first_token_offset[:-1]
 
     ops.cutlass_w4a8_moe_mm(
         mm1_out,
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8(
         permuted_hidden_states=mm2_out,
         topk_weights=topk_weights,
         inv_permuted_idx=inv_perm,
-        expert_first_token_offset=(
-            expert_first_token_offset if expert_map is not None else None
-        ),
+        expert_first_token_offset=expert_first_token_offset,
     )
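On the swap_ab heuristic in run_cutlass_moe_fp8 above: each input row routes to topk distinct experts, so every per-expert M is bounded by the batch row count; when the whole batch has at most 64 rows, swapping logical M and N turns tall-skinny problems [M, 2*N, K] into [2*N, M, K], which the diff's own comment notes reduces tile padding. A tiny sketch (the helper name choose_swap_ab is ours):

def choose_swap_ab(num_tokens: int) -> bool:
    # Mirrors `swap_ab = a1q.size(0) <= 64` from the diff: a1q holds the
    # quantized activations, so its row count bounds every expert's M.
    return num_tokens <= 64

assert choose_swap_ab(16)       # decode-sized batch: use swapped [2*N, M, K]
assert not choose_swap_ab(512)  # prefill-sized batch: keep [M, 2*N, K]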