#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
+ #include <cub/cub.cuh>

#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
@@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
    int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
-     size_t numel, int32_t* __restrict__ cumsum) {
+     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
  extern __shared__ int32_t shared_counts[];

+   // Initialize sorted_token_ids with numel
+   for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
+     sorted_token_ids[it] = numel;
+   }
+
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int my_expert_start = warp_id * experts_per_warp;
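For context, the initialization loop added above uses the usual single-block, thread-stride pattern: each thread starts at its own index and steps by blockDim.x, so the whole buffer is covered regardless of how its length compares to the thread count. A minimal standalone sketch of that pattern; the kernel name and launch shown here are illustrative assumptions, not part of this patch.

// Hypothetical helper, not from the diff: fill buf[0..len) with fill_value
// using a thread-stride loop from a single block.
__global__ void fill_with_value(int32_t* buf, size_t len, int32_t fill_value) {
  for (size_t i = threadIdx.x; i < len; i += blockDim.x) {
    buf[i] = fill_value;
  }
}

// Example launch, mirroring the one-block kernels above:
// fill_with_value<<<1, 1024, 0, stream>>>(sorted_token_ids_ptr, max_num_tokens_padded, numel);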
@@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(

  __syncthreads();

-   if (threadIdx.x == 0) {
-     cumsum[0] = 0;
-     for (int i = 1; i <= num_experts; ++i) {
-       int expert_count = 0;
-       int warp_idx = (i - 1) / experts_per_warp;
-       int expert_offset = (i - 1) % experts_per_warp;
-       expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+   // Compute prefix sum over token counts per expert
+   using BlockScan = cub::BlockScan<int32_t, 1024>;
+   __shared__ typename BlockScan::TempStorage temp_storage;

-       cumsum[i] =
-           cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
-     }
-     *total_tokens_post_pad = cumsum[num_experts];
+   int expert_count = 0;
+   int expert_id = threadIdx.x;
+   if (expert_id < num_experts) {
+     int warp_idx = expert_id / experts_per_warp;
+     int expert_offset = expert_id % experts_per_warp;
+     expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+     expert_count = CEILDIV(expert_count, block_size) * block_size;
+   }
+
+   int cumsum_val;
+   BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
+   if (expert_id <= num_experts) {
+     cumsum[expert_id] = cumsum_val;
+   }
+
+   if (expert_id == num_experts) {
+     *total_tokens_post_pad = cumsum_val;
  }

  __syncthreads();
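For reference, a minimal standalone sketch of the cub::BlockScan pattern introduced above: one thread per expert rounds its count up to a multiple of block_size, and an exclusive sum converts those padded counts into start offsets, with thread num_experts receiving the grand total. The kernel and variable names below are illustrative assumptions, not from the patch.

#include <cub/cub.cuh>

// Hypothetical demo kernel, launched as demo_padded_offsets<<<1, 1024>>>(...):
// one thread per expert, so num_experts must not exceed the block size.
__global__ void demo_padded_offsets(const int32_t* counts, int32_t* offsets,
                                    int num_experts, int block_size) {
  using BlockScan = cub::BlockScan<int32_t, 1024>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  // Round each expert's token count up to a multiple of block_size.
  int32_t padded = 0;
  if (threadIdx.x < num_experts) {
    padded = (counts[threadIdx.x] + block_size - 1) / block_size * block_size;
  }

  // Exclusive prefix sum: thread e receives the start offset of expert e,
  // and thread num_experts receives the total padded token count.
  int32_t offset;
  BlockScan(temp_storage).ExclusiveSum(padded, offset);
  if (threadIdx.x <= num_experts) {
    offsets[threadIdx.x] = offset;
  }
}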
@@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
      expert_ids[i / block_size] = threadIdx.x;
    }
  }
+
+   // Fill remaining expert_ids with 0
+   const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+   const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
+   for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+     expert_ids[i] = 0;
+   }
}

template <typename scalar_t>
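As a concrete, hypothetical illustration of the two fill steps: with numel = 5 routed token slots, block_size = 4, two experts holding 3 and 2 tokens, and sorted_token_ids allocated with 12 entries as an assumed worst-case bound, the scan yields cumsum = [0, 4, 8] and total_tokens_post_pad = 8. The first two expert_ids blocks are written as 0 and 1, the loop above sets the leftover third block to 0, and every sorted_token_ids slot that no real token overwrites (the pad positions 3, 6, 7 and the unused tail 8-11) keeps the sentinel value numel = 5, which is out of range as a token index.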
@@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
-     int32_t block_size, size_t numel) {
+     int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
+   // Initialize sorted_token_ids with numel
+   for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
+     sorted_token_ids[it] = numel;
+   }
+
  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

@@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
    }
  }

+   // Fill remaining expert_ids with 0
+   const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+   const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
+   for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+     expert_ids[i] = 0;
+   }
+
  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad =
@@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  int threads = 1024;
  threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

+   // BlockScan uses 1024 threads and assigns one thread per expert.
+   TORCH_CHECK(padded_num_experts < 1024,
+               "padded_num_experts must be less than 1024");
+
  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        torch::Tensor cumsum_buffer =
-             torch::zeros({num_experts + 1}, options_int);
+             torch::empty({num_experts + 1}, options_int);
        bool small_batch_expert_mode =
            (topk_ids.numel() < 1024) && (num_experts <= 64);
@@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-             topk_ids.numel());
+             topk_ids.numel(), sorted_token_ids.size(0));
        } else {
          auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
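Both launch sites now pass sorted_token_ids.size(0) as max_num_tokens_padded, so the kernels take their fill bounds from however large the caller allocated that tensor. Below is a rough host-side sketch of the assumed allocation; the worst-case bound used here is an illustrative assumption, not something this diff establishes.

// Illustrative only: size sorted_token_ids for the worst case, where every
// expert wastes up to block_size - 1 slots of padding.
const int64_t max_num_tokens_padded =
    topk_ids.numel() + num_experts * (block_size - 1);
torch::Tensor sorted_token_ids = torch::empty(
    {max_num_tokens_padded},
    torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()));
// The kernels above then read sorted_token_ids.size(0) as max_num_tokens_padded.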
@@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
            padded_num_experts, experts_per_warp, block_size,
-             topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+             topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
+             sorted_token_ids.size(0));

          const int block_threads = std::min(256, (int)threads);
          const int num_blocks =