@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
5555 __syncthreads ();
5656
5757 // For each expert we accumulate the token counts from the different threads.
58- for ( int eid = threadIdx .x ; eid < num_experts; eid += blockDim . x ) {
59- tokens_cnts[index (num_experts, 0 , eid )] = 0 ;
58+ if ( threadIdx .x < num_experts) {
59+ tokens_cnts[index (num_experts, 0 , threadIdx . x )] = 0 ;
6060 for (int i = 1 ; i <= blockDim .x ; ++i) {
61- tokens_cnts[index (num_experts, i, eid )] +=
62- tokens_cnts[index (num_experts, i - 1 , eid )];
61+ tokens_cnts[index (num_experts, i, threadIdx . x )] +=
62+ tokens_cnts[index (num_experts, i - 1 , threadIdx . x )];
6363 }
6464 }
6565
@@ -83,9 +83,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
8383 * For each expert, each thread processes the tokens of the corresponding
8484 * blocks and stores the corresponding expert_id for each block.
8585 */
86- for (int eid = threadIdx .x ; eid < num_experts; eid += blockDim .x ) {
87- for (int i = cumsum[eid]; i < cumsum[eid + 1 ]; i += block_size) {
88- expert_ids[i / block_size] = eid;
86+ if (threadIdx .x < num_experts) {
87+ for (int i = cumsum[threadIdx .x ]; i < cumsum[threadIdx .x + 1 ];
88+ i += block_size) {
89+ expert_ids[i / block_size] = threadIdx .x ;
8990 }
9091 }
9192
@@ -140,11 +141,11 @@ __global__ void moe_align_block_size_global_mem_kernel(
140141 __syncthreads ();
141142
142143 // For each expert we accumulate the token counts from the different threads.
143- for ( int eid = threadIdx .x ; eid < num_experts; eid += blockDim . x ) {
144- tokens_cnts[index (num_experts, 0 , eid )] = 0 ;
144+ if ( threadIdx .x < num_experts) {
145+ tokens_cnts[index (num_experts, 0 , threadIdx . x )] = 0 ;
145146 for (int i = 1 ; i <= blockDim .x ; ++i) {
146- tokens_cnts[index (num_experts, i, eid )] +=
147- tokens_cnts[index (num_experts, i - 1 , eid )];
147+ tokens_cnts[index (num_experts, i, threadIdx . x )] +=
148+ tokens_cnts[index (num_experts, i - 1 , threadIdx . x )];
148149 }
149150 }
150151
@@ -168,9 +169,10 @@ __global__ void moe_align_block_size_global_mem_kernel(
168169 * For each expert, each thread processes the tokens of the corresponding
169170 * blocks and stores the corresponding expert_id for each block.
170171 */
171- for (int eid = threadIdx .x ; eid < num_experts; eid += blockDim .x ) {
172- for (int i = cumsum[eid]; i < cumsum[eid + 1 ]; i += block_size) {
173- expert_ids[i / block_size] = eid;
172+ if (threadIdx .x < num_experts) {
173+ for (int i = cumsum[threadIdx .x ]; i < cumsum[threadIdx .x + 1 ];
174+ i += block_size) {
175+ expert_ids[i / block_size] = threadIdx .x ;
174176 }
175177 }
176178
@@ -221,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
221223 torch::Tensor experts_ids,
222224 torch::Tensor num_tokens_post_pad) {
223225 const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
224- VLLM_DISPATCH_INTEGRAL_TYPES (
225- topk_ids.scalar_type (), " moe_align_block_size_kernel" , [&] {
226- // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
227- // tensors
228- const int32_t num_thread = WARP_SIZE;
229- const int32_t shared_mem =
230- ((num_thread + 1 ) * num_experts + (num_experts + 1 )) *
231- sizeof (int32_t );
232-
233- // set dynamic shared mem
234- auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t >;
235- AT_CUDA_CHECK (VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize (
236- (void *)kernel, shared_mem));
237- kernel<<<1 , num_thread, shared_mem, stream>>> (
238- topk_ids.data_ptr <scalar_t >(), sorted_token_ids.data_ptr <int32_t >(),
239- experts_ids.data_ptr <int32_t >(),
240- num_tokens_post_pad.data_ptr <int32_t >(), num_experts, block_size,
241- topk_ids.numel ());
242- });
226+
227+ // If we have very large number of experts, we can no longer use shared
228+ // memory.
229+ // TODO(simon): the right solution should be calculating the exact right
230+ // amount of shared memory and use that. The num_experts >= 256 is just a
231+ // temporary solution to unblock Deepseek V3.
232+ if (num_experts >= 96 ) {
233+ VLLM_DISPATCH_INTEGRAL_TYPES (
234+ topk_ids.scalar_type (), " moe_align_block_size_global_mem_kernel" , [&] {
235+ // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
236+ // tensors
237+ const int32_t num_thread = max ((int32_t )num_experts, WARP_SIZE);
238+
239+ const int32_t mem_tokens_cnts =
240+ ((num_experts + 1 ) * num_experts) * sizeof (int32_t );
241+ const int32_t mem_cumsum = (num_experts + 1 ) * sizeof (int32_t );
242+ // allocate global memory
243+ int32_t * tokens_cnts;
244+ int32_t * cumsum;
245+ cudaMalloc (&tokens_cnts, mem_tokens_cnts);
246+ cudaMalloc (&cumsum, mem_cumsum);
247+
248+ auto kernel =
249+ vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t >;
250+ kernel<<<1 , num_thread, 0 , stream>>> (
251+ topk_ids.data_ptr <scalar_t >(),
252+ sorted_token_ids.data_ptr <int32_t >(),
253+ experts_ids.data_ptr <int32_t >(),
254+ num_tokens_post_pad.data_ptr <int32_t >(), num_experts, block_size,
255+ topk_ids.numel (), tokens_cnts, cumsum);
256+ cudaFree (tokens_cnts);
257+ cudaFree (cumsum);
258+ });
259+ } else {
260+ VLLM_DISPATCH_INTEGRAL_TYPES (
261+ topk_ids.scalar_type (), " moe_align_block_size_kernel" , [&] {
262+ // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
263+ // tensors
264+ const int32_t num_thread = max ((int32_t )num_experts, WARP_SIZE);
265+ const int32_t shared_mem =
266+ ((num_thread + 1 ) * num_experts + (num_experts + 1 )) *
267+ sizeof (int32_t );
268+
269+ // set dynamic shared mem
270+ auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t >;
271+ AT_CUDA_CHECK (VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize (
272+ (void *)kernel, shared_mem));
273+ kernel<<<1 , num_thread, shared_mem, stream>>> (
274+ topk_ids.data_ptr <scalar_t >(),
275+ sorted_token_ids.data_ptr <int32_t >(),
276+ experts_ids.data_ptr <int32_t >(),
277+ num_tokens_post_pad.data_ptr <int32_t >(), num_experts, block_size,
278+ topk_ids.numel ());
279+ });
280+ }
243281}
244282
245283void moe_sum (torch::Tensor& input, // [num_tokens, topk, hidden_size]
0 commit comments