Commit 0ec82ed

[perf] Speed up align sum kernels (vllm-project#21079)
Signed-off-by: Himanshu Jaju <[email protected]>
1 parent 005ae9b commit 0ec82ed

3 files changed: +60 -25 lines changed


benchmarks/kernels/benchmark_moe_align_block_size.py

Lines changed: 2 additions & 5 deletions
@@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
     sorted_ids_triton = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
     )
-    sorted_ids_triton.fill_(topk_ids.numel())  # fill with sentinel value
-    expert_ids_triton = torch.zeros(
+    expert_ids_triton = torch.empty(
         (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")

     sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
-    sorted_ids_vllm.fill_(topk_ids.numel())
-    expert_ids_vllm = torch.zeros_like(expert_ids_triton)
+    expert_ids_vllm = torch.empty_like(expert_ids_triton)
     num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)

     # 2. run implementations
@@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
     expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
     num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
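
Context for these removals (the diff does not spell it out): on a CUDA tensor, fill_ and a torch.zeros allocation each launch their own elementwise kernel and make a full write pass over the buffer. The updated kernels in csrc/moe/moe_align_sum_kernels.cu below perform this initialization themselves, so the benchmark only needs torch.empty. A rough, self-contained CUDA sketch that times the kind of standalone fill launch being eliminated; the kernel, names, and sizes are illustrative, not what PyTorch literally runs:

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Roughly the shape of work behind Tensor.fill_: one thread per element,
// one full pass over global memory, one extra kernel launch.
__global__ void fill_kernel(int32_t* buf, int n, int32_t value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = value;
}

int main() {
  const int n = 1 << 20;
  int32_t* buf;
  cudaMalloc(&buf, n * sizeof(int32_t));

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  fill_kernel<<<(n + 255) / 256, 256>>>(buf, n, /*value=*/n);  // sentinel = numel
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  printf("standalone sentinel fill of %d ints: %.3f ms\n", n, ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaFree(buf);
  return 0;
}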

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 55 additions & 16 deletions
@@ -1,6 +1,7 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <cub/cub.cuh>

 #include <ATen/ATen.h>
 #include <ATen/cuda/Atomic.cuh>
@@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
-    size_t numel, int32_t* __restrict__ cumsum) {
+    size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
   extern __shared__ int32_t shared_counts[];

+  // Initialize sorted_token_ids with numel
+  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
+    sorted_token_ids[it] = numel;
+  }
+
   const int warp_id = threadIdx.x / WARP_SIZE;
   const int my_expert_start = warp_id * experts_per_warp;

@@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(

   __syncthreads();

-  if (threadIdx.x == 0) {
-    cumsum[0] = 0;
-    for (int i = 1; i <= num_experts; ++i) {
-      int expert_count = 0;
-      int warp_idx = (i - 1) / experts_per_warp;
-      int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+  // Compute prefix sum over token counts per expert
+  using BlockScan = cub::BlockScan<int32_t, 1024>;
+  __shared__ typename BlockScan::TempStorage temp_storage;

-      cumsum[i] =
-          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
-    }
-    *total_tokens_post_pad = cumsum[num_experts];
+  int expert_count = 0;
+  int expert_id = threadIdx.x;
+  if (expert_id < num_experts) {
+    int warp_idx = expert_id / experts_per_warp;
+    int expert_offset = expert_id % experts_per_warp;
+    expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+    expert_count = CEILDIV(expert_count, block_size) * block_size;
+  }
+
+  int cumsum_val;
+  BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
+  if (expert_id <= num_experts) {
+    cumsum[expert_id] = cumsum_val;
+  }
+
+  if (expert_id == num_experts) {
+    *total_tokens_post_pad = cumsum_val;
   }

   __syncthreads();
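
This hunk is the core of the speedup: the serial loop that thread 0 ran over every expert becomes a cooperative exclusive prefix sum across the block. Each thread owns one expert's block-padded count; after the scan, thread i holds the sum of the counts for experts 0..i-1, so thread num_experts (whose own count is 0) holds the grand total, which is why exactly that thread writes *total_tokens_post_pad. It is also why the host code below gains TORCH_CHECK(padded_num_experts < 1024): the scan is instantiated for exactly 1024 threads, one expert per thread. A self-contained sketch of cub::BlockScan::ExclusiveSum with an illustrative block size of 128 and all-ones input (not vLLM's configuration):

#include <cub/cub.cuh>
#include <cstdio>

constexpr int kThreads = 128;  // illustrative; the kernel above uses 1024

__global__ void exclusive_sum_demo(const int* in, int* out) {
  // Each thread contributes one value; CUB scans the block cooperatively.
  using BlockScan = cub::BlockScan<int, kThreads>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  int val = in[threadIdx.x];
  int prefix;
  // prefix = in[0] + ... + in[threadIdx.x - 1], computed in O(log kThreads)
  // parallel steps instead of one thread looping over all inputs.
  BlockScan(temp_storage).ExclusiveSum(val, prefix);
  out[threadIdx.x] = prefix;
}

int main() {
  int h_in[kThreads], h_out[kThreads];
  for (int i = 0; i < kThreads; ++i) h_in[i] = 1;  // all ones: prefix of i is i
  int *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  exclusive_sum_demo<<<1, kThreads>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("out[0]=%d out[5]=%d out[%d]=%d\n",  // expect 0, 5, 127
         h_out[0], h_out[5], kThreads - 1, h_out[kThreads - 1]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}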
@@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
       expert_ids[i / block_size] = threadIdx.x;
     }
   }
+
+  // Fill remaining expert_ids with 0
+  const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+  const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
+  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+    expert_ids[i] = 0;
+  }
 }

 template <typename scalar_t>
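
The two loops added to this kernel, the sentinel fill of sorted_token_ids near the top and this zero fill of expert_ids' padding tail, share one single-block stride pattern, and together they replace the host-side fill_/zeros calls removed elsewhere in the commit. The zero fill is needed for correctness, not just hygiene: as the comment removed from moe_align_block_size.py below notes, padding blocks must still carry a valid expert id, or mapping global expert ids to local ones under expert parallelism can index out of bounds. A minimal standalone sketch of the pattern; fused_fill_demo and its sizes are illustrative, not vLLM code:

#include <cuda_runtime.h>
#include <cstdint>

// Illustrative only: initialize output buffers inside the kernel that
// consumes them, instead of paying for separate fill launches from the host.
__global__ void fused_fill_demo(int32_t* sorted_token_ids, int n_sorted,
                                int32_t sentinel, int32_t* expert_ids,
                                int n_expert) {
  // Block-stride loop: thread t covers indices t, t + blockDim.x, ...
  for (int it = threadIdx.x; it < n_sorted; it += blockDim.x) {
    sorted_token_ids[it] = sentinel;  // sentinel marks padding slots
  }
  for (int it = threadIdx.x; it < n_expert; it += blockDim.x) {
    expert_ids[it] = 0;  // padding blocks must still hold a valid expert id
  }
  // ... the real kernel would then overwrite the valid prefixes ...
}

int main() {
  const int n_sorted = 4096, n_expert = 256;
  int32_t *sorted, *experts;
  cudaMalloc(&sorted, n_sorted * sizeof(int32_t));
  cudaMalloc(&experts, n_expert * sizeof(int32_t));
  fused_fill_demo<<<1, 1024>>>(sorted, n_sorted, /*sentinel=*/n_sorted,
                               experts, n_expert);
  cudaDeviceSynchronize();
  cudaFree(sorted);
  cudaFree(experts);
  return 0;
}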
@@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
-    int32_t block_size, size_t numel) {
+    int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
+  // Initialize sorted_token_ids with numel
+  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
+    sorted_token_ids[it] = numel;
+  }
+
   const size_t tid = threadIdx.x;
   const size_t stride = blockDim.x;

@@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     }
   }

+  // Fill remaining expert_ids with 0
+  const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+  const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
+  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+    expert_ids[i] = 0;
+  }
+
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
     int32_t rank_post_pad =

@@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   int threads = 1024;
   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

+  // BlockScan uses 1024 threads and assigns one thread per expert.
+  TORCH_CHECK(padded_num_experts < 1024,
+              "padded_num_experts must be less than 1024");
+
   VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `cumsum` tensors
         auto options_int =
             torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
         torch::Tensor cumsum_buffer =
-            torch::zeros({num_experts + 1}, options_int);
+            torch::empty({num_experts + 1}, options_int);
         bool small_batch_expert_mode =
             (topk_ids.numel() < 1024) && (num_experts <= 64);

Downgrading cumsum_buffer from torch::zeros to torch::empty is safe here because the BlockScan path now writes every entry of cumsum itself, and the small-batch kernel keeps its own shared-memory cumsum.
@@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
             sorted_token_ids.data_ptr<int32_t>(),
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
+            topk_ids.numel(), sorted_token_ids.size(0));
       } else {
         auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;

@@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
             padded_num_experts, experts_per_warp, block_size,
-            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
+            sorted_token_ids.size(0));

         const int block_threads = std::min(256, (int)threads);
         const int num_blocks =

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 3 additions & 4 deletions
@@ -111,6 +111,8 @@ def moe_align_block_size_triton(
                               dtype=torch.int32,
                               device=topk_ids.device)
     tokens_per_thread = cdiv(numel, num_experts)
+    sorted_token_ids.fill_(numel)
+    expert_ids.zero_()

     moe_align_block_size_stage1[grid](
         topk_ids,

The fill_ and zero_ move into moe_align_block_size_triton because moe_align_block_size (next hunk) no longer pre-initializes the buffers it allocates; unlike the updated CUDA kernels above, the Triton kernels do not initialize sorted_token_ids and expert_ids themselves.
@@ -205,11 +207,8 @@ def moe_align_block_size(
     sorted_ids = torch.empty((max_num_tokens_padded, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    # Expert ids must be zeroed out to prevent index out of bounds error while
-    # mapping global expert ids to local expert ids in expert parallelism.
-    expert_ids = torch.zeros((max_num_m_blocks, ),
+    expert_ids = torch.empty((max_num_m_blocks, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
     num_tokens_post_pad = torch.empty((1),
