@@ -1,6 +1,7 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <cub/cub.cuh>

 #include <ATen/ATen.h>
 #include <ATen/cuda/Atomic.cuh>
@@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
-    size_t numel, int32_t* __restrict__ cumsum) {
+    size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
   extern __shared__ int32_t shared_counts[];

+  // Initialize sorted_token_ids with numel
+  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
+    sorted_token_ids[it] = numel;
+  }
+
   const int warp_id = threadIdx.x / WARP_SIZE;
   const int my_expert_start = warp_id * experts_per_warp;

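For intuition, here is a small worked illustration (the token routing, numel, and block_size values are assumed for the example, not taken from this diff) of why the whole `sorted_token_ids` buffer is pre-filled with `numel`: padding slots that no token ever claims keep that sentinel value, which reads as an out-of-range token index downstream.

// Assumed example: numel = 5 tokens routed to experts [e0, e0, e1, e0, e1],
// block_size = 4, so each expert's count is padded up to 4 slots.
//   expert e0 slots: tokens 0, 1, 3, then one pad  -> 0, 1, 3, 5
//   expert e1 slots: tokens 2, 4, then two pads    -> 2, 4, 5, 5
//   sorted_token_ids = [0, 1, 3, 5, 2, 4, 5, 5]    (every 5 == numel sentinel)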
@@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(

   __syncthreads();

-  if (threadIdx.x == 0) {
-    cumsum[0] = 0;
-    for (int i = 1; i <= num_experts; ++i) {
-      int expert_count = 0;
-      int warp_idx = (i - 1) / experts_per_warp;
-      int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+  // Compute prefix sum over token counts per expert
+  using BlockScan = cub::BlockScan<int32_t, 1024>;
+  __shared__ typename BlockScan::TempStorage temp_storage;

-      cumsum[i] =
-          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
-    }
-    *total_tokens_post_pad = cumsum[num_experts];
+  int expert_count = 0;
+  int expert_id = threadIdx.x;
+  if (expert_id < num_experts) {
+    int warp_idx = expert_id / experts_per_warp;
+    int expert_offset = expert_id % experts_per_warp;
+    expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
+    expert_count = CEILDIV(expert_count, block_size) * block_size;
+  }
+
+  int cumsum_val;
+  BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
+  if (expert_id <= num_experts) {
+    cumsum[expert_id] = cumsum_val;
+  }
+
+  if (expert_id == num_experts) {
+    *total_tokens_post_pad = cumsum_val;
   }

   __syncthreads();
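For reference, a minimal standalone sketch of the `cub::BlockScan` exclusive-sum pattern the hunk above switches to (the kernel name, template parameter, and launch shape here are illustrative assumptions, not part of this PR):

#include <cub/cub.cuh>

// One thread block; thread i contributes expert i's padded count and receives
// the sum of all counts before it, mirroring the cumsum[] computed above.
template <int BLOCK_THREADS>
__global__ void exclusive_sum_demo(const int32_t* padded_counts,
                                   int32_t* offsets, int num_experts) {
  using BlockScan = cub::BlockScan<int32_t, BLOCK_THREADS>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  int32_t count =
      (threadIdx.x < num_experts) ? padded_counts[threadIdx.x] : 0;
  int32_t offset;
  BlockScan(temp_storage).ExclusiveSum(count, offset);

  // offsets[i] = padded_counts[0] + ... + padded_counts[i - 1];
  // offsets[num_experts] is the grand total (inputs past num_experts are 0).
  if (threadIdx.x <= num_experts) offsets[threadIdx.x] = offset;
}

// Launch sketch: exclusive_sum_demo<1024><<<1, 1024>>>(counts, offsets, E);

Because only thread `num_experts` receives the grand total as its exclusive sum, the PR has that same thread publish `*total_tokens_post_pad`.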
@@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
       expert_ids[i / block_size] = threadIdx.x;
     }
   }
+
+  // Fill remaining expert_ids with 0
+  const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+  const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
+  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+    expert_ids[i] = 0;
+  }
 }

 template <typename scalar_t>
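As a quick sanity check of the padding arithmetic in the two added loops (the numbers are illustrative, and `CEILDIV(a, b)` is assumed to be the usual `(a + b - 1) / b`):

// With block_size = 16:
//   an expert that received 21 tokens occupies CEILDIV(21, 16) * 16 = 32 slots;
//   with max_num_tokens_padded = 96, expert_ids_size = CEILDIV(96, 16) = 6.
// Only the first cumsum[num_experts] / block_size entries of expert_ids hold
// real expert indices; the tail is zero-filled by the loop above so the buffer
// never exposes uninitialized memory.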
@@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
-    int32_t block_size, size_t numel) {
+    int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
+  // Initialize sorted_token_ids with numel
+  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
+    sorted_token_ids[it] = numel;
+  }
+
   const size_t tid = threadIdx.x;
   const size_t stride = blockDim.x;

@@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     }
   }

+  // Fill remaining expert_ids with 0
+  const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+  const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
+  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+    expert_ids[i] = 0;
+  }
+
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
     int32_t rank_post_pad =
@@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   int threads = 1024;
   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

+  // BlockScan uses 1024 threads and assigns one thread per expert.
+  TORCH_CHECK(padded_num_experts < 1024,
+              "padded_num_experts must be less than 1024");
+
   VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `cumsum` tensors
         auto options_int =
             torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
         torch::Tensor cumsum_buffer =
-            torch::zeros({num_experts + 1}, options_int);
+            torch::empty({num_experts + 1}, options_int);
         bool small_batch_expert_mode =
             (topk_ids.numel() < 1024) && (num_experts <= 64);

@@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
             sorted_token_ids.data_ptr<int32_t>(),
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
+            topk_ids.numel(), sorted_token_ids.size(0));
       } else {
         auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;

@@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
             padded_num_experts, experts_per_warp, block_size,
-            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
+            sorted_token_ids.size(0));

         const int block_threads = std::min(256, (int)threads);
         const int num_blocks =
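Finally, a hedged host-side sketch of how the output buffers might be sized so that `sorted_token_ids.size(0)` matches the kernel's new `max_num_tokens_padded` argument; the worst-case formula and the call at the end are assumptions about the surrounding code, not something this diff shows.

// Assumption: in the worst case every expert rounds up to one extra
// partially-filled block, hence the num_experts * (block_size - 1) slack.
auto opts =
    torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
int64_t max_num_tokens_padded =
    topk_ids.numel() + num_experts * (block_size - 1);
torch::Tensor sorted_token_ids = torch::empty({max_num_tokens_padded}, opts);
torch::Tensor experts_ids = torch::empty(
    {(max_num_tokens_padded + block_size - 1) / block_size}, opts);
torch::Tensor num_tokens_post_pad = torch::empty({1}, opts);
// Hypothetical call (argument order assumed):
// moe_align_block_size(topk_ids, num_experts, block_size,
//                      sorted_token_ids, experts_ids, num_tokens_post_pad);

Since the kernels now pre-fill `sorted_token_ids` and the tail of `expert_ids` themselves, plain `torch::empty` allocations suffice on the host, which is the same reasoning behind switching `cumsum_buffer` from `torch::zeros` to `torch::empty` in the hunk above.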