@@ -377,11 +377,14 @@ inline __device__ T add(T a, T b)
377377
378378#define FINAL_MASK 0xffffffff
379379#define WARP_SIZE 32
380+ #define LOG2_WARP_SIZE 5
381+ #define LANE_ID_MASK 0x1f
380382
381383template <typename T>
382- __inline__ __device__ T warpReduceSum (T val)
384+ __inline__ __device__ T warpReduceSumPartial (T val)
383385{
384- int lane_id = threadIdx .x & 0x1f ;
386+ int lane_id = threadIdx .x & LANE_ID_MASK;
387+ // This function should only be called on the last warp
385388 int warp_size = blockDim .x - (threadIdx .x & ~(WARP_SIZE - 1 ));
386389 unsigned int active_mask = (1U << warp_size) - 1 ;
387390
@@ -395,14 +398,25 @@ __inline__ __device__ T warpReduceSum(T val)
395398 return val;
396399}
397400
401+ template <typename T>
402+ inline __device__ T warpReduceSumFull (T val)
403+ {
404+ #pragma unroll
405+ for (int offset = 16 ; offset > 0 ; offset >>= 1 )
406+ {
407+ val = add<T>(val, __shfl_xor_sync (FINAL_MASK, val, offset, WARP_SIZE));
408+ }
409+ return val;
410+ }
398411inline __device__ float block_reduce_sum (float val)
399412{
400413 __shared__ float smem[WARP_SIZE];
401- int lane_id = threadIdx .x % WARP_SIZE;
402- int warp_id = threadIdx .x / WARP_SIZE;
403- int warp_num = (blockDim .x + WARP_SIZE - 1 ) / WARP_SIZE; // Ceiling division to include partial warps
414+ int lane_id = threadIdx .x & LANE_ID_MASK;
415+ int warp_id = threadIdx .x >> LOG2_WARP_SIZE;
416+ bool has_partial_warp = (blockDim .x % WARP_SIZE) != 0 ;
417+ int warp_num = (blockDim .x + WARP_SIZE - 1 ) >> LOG2_WARP_SIZE; // Ceiling division to include partial warps
404418
405- val = warpReduceSum (val);
419+ val = (has_partial_warp && (warp_id == warp_num - 1 )) ? warpReduceSumPartial (val) : warpReduceSumFull (val);
406420 if (lane_id == 0 )
407421 {
408422 smem[warp_id] = val;
@@ -411,10 +425,10 @@ inline __device__ float block_reduce_sum(float val)
411425 if (warp_id == 0 )
412426 {
413427 val = lane_id < warp_num ? smem[lane_id] : 0 .f ;
414- val = warpReduceSum (val);
428+ val = (has_partial_warp && warp_num == 1 ) ? warpReduceSumPartial (val) : warpReduceSumFull (val);
415429 if (lane_id == 0 )
416430 {
417- val = smem[0 ];
431+ smem[0 ] = val ;
418432 }
419433 }
420434 __syncthreads ();
0 commit comments