Skip to content

Commit 1524763

Browse files
committed
Fix warp reduce issue.
Signed-off-by: Shiyu Li <[email protected]>
1 parent a81a658 commit 1524763

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,14 @@ inline __device__ T add(T a, T b)
377377

378378
#define FINAL_MASK 0xffffffff
379379
#define WARP_SIZE 32
380+
#define LOG2_WARP_SIZE 5
381+
#define LANE_ID_MASK 0x1f
380382

381383
template <typename T>
382-
__inline__ __device__ T warpReduceSum(T val)
384+
__inline__ __device__ T warpReduceSumPartial(T val)
383385
{
384-
int lane_id = threadIdx.x & 0x1f;
386+
int lane_id = threadIdx.x & LANE_ID_MASK;
387+
// This function should only be called on the last warp
385388
int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1));
386389
unsigned int active_mask = (1U << warp_size) - 1;
387390

@@ -395,14 +398,25 @@ __inline__ __device__ T warpReduceSum(T val)
395398
return val;
396399
}
397400

401+
template <typename T>
402+
inline __device__ T warpReduceSumFull(T val)
403+
{
404+
#pragma unroll
405+
for (int offset = 16; offset > 0; offset >>= 1)
406+
{
407+
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE));
408+
}
409+
return val;
410+
}
398411
inline __device__ float block_reduce_sum(float val)
399412
{
400413
__shared__ float smem[WARP_SIZE];
401-
int lane_id = threadIdx.x % WARP_SIZE;
402-
int warp_id = threadIdx.x / WARP_SIZE;
403-
int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps
414+
int lane_id = threadIdx.x & LANE_ID_MASK;
415+
int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
416+
bool has_partial_warp = (blockDim.x % WARP_SIZE) != 0;
417+
int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE; // Ceiling division to include partial warps
404418

405-
val = warpReduceSum(val);
419+
val = (has_partial_warp && (warp_id == warp_num - 1)) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
406420
if (lane_id == 0)
407421
{
408422
smem[warp_id] = val;
@@ -411,10 +425,10 @@ inline __device__ float block_reduce_sum(float val)
411425
if (warp_id == 0)
412426
{
413427
val = lane_id < warp_num ? smem[lane_id] : 0.f;
414-
val = warpReduceSum(val);
428+
val = (has_partial_warp && warp_num == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
415429
if (lane_id == 0)
416430
{
417-
val = smem[0];
431+
smem[0] = val;
418432
}
419433
}
420434
__syncthreads();

0 commit comments

Comments
 (0)