@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  if (threadIdx.x < num_experts) {
-    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+    tokens_cnts[index(num_experts, 0, eid)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
-          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+      tokens_cnts[index(num_experts, i, eid)] +=
+          tokens_cnts[index(num_experts, i - 1, eid)];
     }
   }
 
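The hunk above replaces the one-thread-per-expert guard (`threadIdx.x < num_experts`) with a block-stride loop, so a block of fixed size can cover any number of experts. A minimal, self-contained sketch of that pattern (the kernel and its names are hypothetical, not part of this diff):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Block-stride loop: each of the blockDim.x threads handles experts
// eid, eid + blockDim.x, eid + 2 * blockDim.x, ... so the kernel no
// longer requires blockDim.x >= num_experts.
__global__ void block_stride_demo(int* per_expert, int num_experts) {
  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
    per_expert[eid] = eid;  // stand-in for the per-expert work
  }
}

int main() {
  const int num_experts = 100;  // deliberately larger than the block
  int* d = nullptr;
  cudaMalloc(&d, num_experts * sizeof(int));
  block_stride_demo<<<1, 32>>>(d, num_experts);  // a single 32-thread warp
  int h[num_experts];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("expert 99 handled: %d\n", h[99]);  // prints 99
  cudaFree(d);
  return 0;
}
```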
@@ -83,10 +83,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
-         i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
+  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
+      expert_ids[i / block_size] = eid;
     }
   }
 
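Same transformation here. `expert_ids` maps each `block_size`-sized slice of the padded token array to the expert that owns it, using the `cumsum` boundaries, which the kernel has already rounded up to multiples of `block_size`. A host-side sketch with hypothetical data to illustrate the mapping:

```cuda
#include <cstdio>

int main() {
  // Hypothetical aligned cumsum for 3 experts with block_size = 4:
  // expert 0 owns slots [0, 8), expert 1 owns [8, 12), expert 2 owns [12, 20).
  const int block_size = 4;
  const int num_experts = 3;
  const int cumsum[num_experts + 1] = {0, 8, 12, 20};
  int expert_ids[5];  // 20 slots / block_size = 5 blocks

  // The same mapping the kernel computes: one expert id per block.
  for (int eid = 0; eid < num_experts; ++eid)
    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size)
      expert_ids[i / block_size] = eid;

  for (int b = 0; b < 5; ++b)
    printf("block %d -> expert %d\n", b, expert_ids[b]);
  // blocks 0,1 -> expert 0; block 2 -> expert 1; blocks 3,4 -> expert 2
  return 0;
}
```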
@@ -141,7 +140,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
         // tensors
-        const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+        const int32_t num_thread = WARP_SIZE;
         const int32_t shared_mem =
             ((num_thread + 1) * num_experts + (num_experts + 1)) *
             sizeof(int32_t);
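With `num_thread` pinned to `WARP_SIZE`, the shared-memory request grows linearly in `num_experts` instead of quadratically. A quick check of the sizing formula from the hunk above (assuming WARP_SIZE is 32 and picking 256 experts for illustration):

```cuda
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t WARP_SIZE = 32;     // assumed warp size, as on NVIDIA GPUs
  const int32_t num_experts = 256;  // illustrative expert count

  // Before this change: num_thread = max(num_experts, WARP_SIZE), so the
  // tokens_cnts part grows quadratically with num_experts.
  const size_t before =
      ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  // After: num_thread = WARP_SIZE, so it grows linearly.
  const size_t after =
      ((WARP_SIZE + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);

  printf("before: %zu bytes (~258 KiB), after: %zu bytes (~34 KiB)\n",
         before, after);
  return 0;
}
```

At 256 experts the old sizing requests roughly 258 KiB of shared memory per block, more than current GPUs allow, while the fixed warp-size version stays around 34 KiB.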