Skip to content

Commit b414ae9

Browse files
authored
Always use 64 as the block size of the moe_align kernel to avoid exceeding the LDS limit (ROCm#303)
* always use 64 as the block size to avoid exceeding the LDS limit
* lint
1 parent ccdb5b8 commit b414ae9

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
5555
__syncthreads();
5656

5757
// For each expert we accumulate the token counts from the different threads.
58-
if (threadIdx.x < num_experts) {
59-
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
58+
for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
59+
tokens_cnts[index(num_experts, 0, eid)] = 0;
6060
for (int i = 1; i <= blockDim.x; ++i) {
61-
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
62-
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
61+
tokens_cnts[index(num_experts, i, eid)] +=
62+
tokens_cnts[index(num_experts, i - 1, eid)];
6363
}
6464
}
6565

@@ -83,10 +83,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
8383
* For each expert, each thread processes the tokens of the corresponding
8484
* blocks and stores the corresponding expert_id for each block.
8585
*/
86-
if (threadIdx.x < num_experts) {
87-
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
88-
i += block_size) {
89-
expert_ids[i / block_size] = threadIdx.x;
86+
for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
87+
for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
88+
expert_ids[i / block_size] = eid;
9089
}
9190
}
9291

@@ -141,7 +140,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
141140
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
142141
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
143142
// tensors
144-
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
143+
const int32_t num_thread = WARP_SIZE;
145144
const int32_t shared_mem =
146145
((num_thread + 1) * num_experts + (num_experts + 1)) *
147146
sizeof(int32_t);

0 commit comments

Comments
 (0)