diff --git a/csrc/all_to_all/internode_dispatch.cu b/csrc/all_to_all/internode_dispatch.cu index 5956b90..36e1266 100644 --- a/csrc/all_to_all/internode_dispatch.cu +++ b/csrc/all_to_all/internode_dispatch.cu @@ -212,11 +212,11 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( for (unsigned i = threadIdx.x; i < numTokens; i += blockDim.x) { std::byte *xTokenBuffer = xBufferOut + (group * maxNumTokens + i) * tokenStride; uint32_t token = tokenStart + i; - sourceIndex[token] = *((uint32_t *)(xTokenBuffer + tokenDim)); + sourceIndex[token] = i; sourceExpert[token] = expert; sourceOffset[token] = expertStart + i; sourceGroup[token] = dp; - sourceToken[token] = i; + sourceToken[token] = *((uint32_t *)(xTokenBuffer + tokenDim)); } } @@ -228,7 +228,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( auto expert = sourceExpert[i]; auto group = expert * numDPGroups + sourceGroup[i]; - std::byte *xTokenBuffer = xBufferOut + (group * maxNumTokens + sourceToken[i]) * tokenStride; + std::byte *xTokenBuffer = xBufferOut + (group * maxNumTokens + sourceIndex[i]) * tokenStride; std::byte *dstXExpert = expertX + expert * expertXStrideRow; float *dstXScaleExpert = expertXScale + expert * expertXScaleStrideCol;