diff --git a/csrc/all_to_all/intranode_dispatch.cu b/csrc/all_to_all/intranode_dispatch.cu index 402773b..ac72f8f 100644 --- a/csrc/all_to_all/intranode_dispatch.cu +++ b/csrc/all_to_all/intranode_dispatch.cu @@ -203,8 +203,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( unsigned firstGroup = blockIdx.x * expertsPerBlock; unsigned lastGroup = std::min(firstGroup + expertsPerBlock, numExpertsAndRanks); - for (unsigned group = firstGroup + threadIdx.x; group < lastGroup; - group += gridDim.x * expertsPerBlock) { + for (unsigned group = firstGroup + threadIdx.x; group < lastGroup; group += blockDim.x) { const uint32_t srcRank = group / numLocalExperts; const uint32_t srcLocalExpert = group % numLocalExperts;