diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu
index f439f2b8efd..d8c79485c2d 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu
@@ -377,21 +377,34 @@ inline __device__ T add(T a, T b)
 
 #define FINAL_MASK 0xffffffff
 #define WARP_SIZE 32
+#define LOG2_WARP_SIZE 5
+#define LANE_ID_MASK 0x1f
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val)
+__inline__ __device__ T warpReduceSumPartial(T val)
 {
-    // Get the actual number of active threads in this warp
-    int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
-    unsigned int mask = (1U << active_warp_size) - 1;
+    int lane_id = threadIdx.x & LANE_ID_MASK;
+    // This function should only be called on the last warp
+    int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1));
+    unsigned int active_mask = (1U << warp_size) - 1;
 
 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1)
    {
-        if (offset < active_warp_size)
-        {
-            val = add(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
-        }
+        int target_lane = lane_id ^ offset;
+        auto tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
+        val = add(val, target_lane < warp_size ? tmp : 0);
+    }
+    return val;
+}
+
+template <typename T>
+inline __device__ T warpReduceSumFull(T val)
+{
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1)
+    {
+        val = add(val, __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE));
     }
     return val;
 }
@@ -399,18 +412,28 @@ __inline__ __device__ T warpReduceSum(T val)
 inline __device__ float block_reduce_sum(float val)
 {
     __shared__ float smem[WARP_SIZE];
-    int lane_id = threadIdx.x % WARP_SIZE;
-    int warp_id = threadIdx.x / WARP_SIZE;
-    int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps
+    int lane_id = threadIdx.x & LANE_ID_MASK;
+    int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
+    bool has_partial_warp = (blockDim.x % WARP_SIZE) != 0;
+    int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE; // Ceiling division to include partial warps
 
-    val = warpReduceSum(val);
+    val = (has_partial_warp && (warp_id == warp_num - 1)) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
     if (lane_id == 0)
     {
         smem[warp_id] = val;
     }
     __syncthreads();
-    val = lane_id < warp_num ? smem[lane_id] : 0.f;
-    val = warpReduceSum(val);
+    if (warp_id == 0)
+    {
+        val = lane_id < warp_num ? smem[lane_id] : 0.f;
+        val = (has_partial_warp && warp_num == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
+        if (lane_id == 0)
+        {
+            smem[0] = val;
+        }
+    }
+    __syncthreads();
+    val = smem[0];
     return val;
 }
 
diff --git a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
index 9997f9dbbf5..53692c51921 100644
--- a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
+++ b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
@@ -161,7 +161,6 @@ def func(input, residual, norm_weight, eps, enable_fusion):
     )
 
 
-@pytest.mark.skip(reason="https://nvbugs/5597647")
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="needs 2 GPUs to run this test")
 @pytest.mark.parametrize(
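
For context, below is a small standalone sketch (not part of the patch) of the technique the diff applies: split the warp reduction into a full-warp path that uses FINAL_MASK and a partial-warp path for a trailing warp with fewer than 32 threads, then have warp 0 combine the per-warp sums and broadcast the result through shared memory. The harness names (warp_reduce_full, warp_reduce_partial, block_reduce, block_sum_kernel) and the 1000-thread launch are illustrative assumptions, not code from the repository.

// Standalone sketch of the full/partial warp-reduction split, assuming float data.
#include <cstdio>
#include <cuda_runtime.h>

#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32
#define LOG2_WARP_SIZE 5
#define LANE_ID_MASK 0x1f

// Full warp: all 32 lanes exist, so every lane participates in the shuffle.
__device__ float warp_reduce_full(float val)
{
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE);
    return val;
}

// Partial (last) warp: only the first warp_size lanes exist, so the shuffle mask is
// truncated and values coming from non-existent lanes are discarded. Only lane 0 is
// guaranteed to end up with the complete sum, which is all the caller needs.
// Must only be called when warp_size < 32 (otherwise the shift below is undefined).
__device__ float warp_reduce_partial(float val)
{
    int lane_id = threadIdx.x & LANE_ID_MASK;
    int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1));
    unsigned int active_mask = (1U << warp_size) - 1;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        int target_lane = lane_id ^ offset;
        float tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
        val += (target_lane < warp_size) ? tmp : 0.f;
    }
    return val;
}

__device__ float block_reduce(float val)
{
    __shared__ float smem[WARP_SIZE];
    int lane_id = threadIdx.x & LANE_ID_MASK;
    int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
    bool has_partial_warp = (blockDim.x % WARP_SIZE) != 0;
    int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE;

    // Stage 1: each warp reduces its own values; only the last warp may be partial.
    val = (has_partial_warp && warp_id == warp_num - 1) ? warp_reduce_partial(val) : warp_reduce_full(val);
    if (lane_id == 0)
        smem[warp_id] = val;
    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp sums; the result is broadcast through
    // smem[0] so every thread in the block returns the same block-wide sum.
    if (warp_id == 0)
    {
        val = lane_id < warp_num ? smem[lane_id] : 0.f;
        val = (has_partial_warp && warp_num == 1) ? warp_reduce_partial(val) : warp_reduce_full(val);
        if (lane_id == 0)
            smem[0] = val;
    }
    __syncthreads();
    return smem[0];
}

__global__ void block_sum_kernel(float* out)
{
    // Each thread contributes 1.0f, so the block-wide sum should equal blockDim.x.
    float sum = block_reduce(1.f);
    if (threadIdx.x == 0)
        out[0] = sum;
}

int main()
{
    float* d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float));
    // 1000 is deliberately not a multiple of 32: the last warp has only 8 threads,
    // which is exactly the case the partial-warp path is meant to handle.
    block_sum_kernel<<<1, 1000>>>(d_out);
    float h_out = 0.f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block sum = %f (expected 1000)\n", h_out);
    cudaFree(d_out);
    return 0;
}

The broadcast through smem[0] is what the patch adds to block_reduce_sum: the old code returned the second warpReduceSum result directly, which is only well defined inside warp 0, whereas the new version makes every thread read the final value after the second __syncthreads().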