@@ -377,40 +377,63 @@ inline __device__ T add(T a, T b)
377377
// Participation mask naming all 32 lanes of a warp; passed to the *_sync warp intrinsics.
#define FINAL_MASK 0xffffffff
// Number of threads (lanes) per warp.
#define WARP_SIZE 32
// log2(WARP_SIZE): shifting a thread index right by this many bits yields its warp index.
#define LOG2_WARP_SIZE 5
// ANDing a thread index with this mask yields its lane index within the warp.
// Derived from WARP_SIZE so the two constants cannot drift apart.
#define LANE_ID_MASK (WARP_SIZE - 1)

// The shift/mask arithmetic built on these constants is only valid while they agree.
static_assert((1 << LOG2_WARP_SIZE) == WARP_SIZE, "LOG2_WARP_SIZE must equal log2(WARP_SIZE)");
static_assert((WARP_SIZE & (WARP_SIZE - 1)) == 0, "WARP_SIZE must be a power of two");
380382
/**
 * Sum-reduce `val` across the lanes of a warp that may be only partially
 * populated, i.e. the trailing warp of a block whose blockDim.x is not a
 * multiple of WARP_SIZE.
 *
 * The number of live lanes is computed from blockDim.x and threadIdx.x only,
 * so a one-dimensional thread block is assumed -- TODO confirm against callers.
 * Every live lane of the warp must call this function together, because the
 * shuffles below name all of them in their participation mask.
 *
 * @param val  This lane's contribution to the sum.
 * @return     The sum over all live lanes of the calling warp, in every live lane.
 */
template <typename T>
__inline__ __device__ T warpReduceSumPartial(T val)
{
    const int lane_id = threadIdx.x & LANE_ID_MASK;

    // Threads of the block that remain from the first lane of this warp onward,
    // clamped to a full warp so the function is also safe on a full warp.
    const int warp_base = static_cast<int>(threadIdx.x) & ~(WARP_SIZE - 1);
    const int remaining = static_cast<int>(blockDim.x) - warp_base;
    const int warp_size = remaining < WARP_SIZE ? remaining : WARP_SIZE;

    // Shifting a 32-bit value by 32 is undefined, so the full-warp mask is special-cased.
    const unsigned int active_mask = (warp_size >= WARP_SIZE) ? FINAL_MASK : ((1U << warp_size) - 1U);

#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        const int target_lane = lane_id ^ offset;
        // The shuffle is issued unconditionally so every lane in active_mask takes part;
        // what it returns is meaningless when the partner lane does not exist.
        auto tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
        // Accumulate only from partners that exist. Skipping the add (rather than
        // adding a literal zero) keeps this usable for any T that add<T> supports.
        if (target_lane < warp_size)
        {
            val = add<T>(val, tmp);
        }
    }

    // With missing partners the butterfly leaves the complete sum only in lane 0
    // (its partners are the offsets themselves, so a missing partner implies an
    // empty sub-tree). Broadcast it so every live lane returns the same total,
    // matching what warpReduceSumFull delivers.
    return __shfl_sync(active_mask, val, 0, WARP_SIZE);
}
400+
/**
 * Sum-reduce `val` across a fully populated warp.
 *
 * Uses an XOR butterfly exchange with the all-lanes mask (FINAL_MASK), so all
 * WARP_SIZE lanes of the warp must call this together.
 *
 * @param val  This lane's contribution to the sum.
 * @return     The sum of `val` over the warp; every lane receives the same total.
 */
template <typename T>
inline __device__ T warpReduceSumFull(T val)
{
    T total = val;
#pragma unroll
    for (int distance = 16; distance >= 1; distance >>= 1)
    {
        // Fetch the running sum held by the lane whose index differs in the `distance` bit.
        total = add<T>(total, __shfl_xor_sync(FINAL_MASK, total, distance, WARP_SIZE));
    }
    return total;
}
398411
/**
 * Sum `val` over every thread of the block and return the total to all threads.
 *
 * Stage 1 reduces inside each warp and lane 0 of each warp stores its warp's sum
 * in shared memory. Stage 2 has warp 0 reduce those per-warp sums and publish
 * the block total in smem[0], which every thread then reads.
 *
 * Must be called by all threads of the block, since it contains block-wide
 * barriers. Thread and warp indices are derived from threadIdx.x / blockDim.x
 * only, so a one-dimensional block is assumed -- TODO confirm against callers.
 *
 * @param val  This thread's contribution.
 * @return     The sum of `val` over the whole block, identical in every thread.
 */
inline __device__ float block_reduce_sum(float val)
{
    __shared__ float smem[WARP_SIZE];

    const int lane_id = threadIdx.x & LANE_ID_MASK;
    const int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
    const bool has_partial_warp = (blockDim.x & LANE_ID_MASK) != 0;
    // Ceiling division so a trailing partial warp is counted.
    const int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE;
    // Only the last warp of the block can be partially populated.
    const bool in_partial_warp = has_partial_warp && (warp_id == warp_num - 1);

    // smem is shared by every invocation of this function. Without this barrier a
    // warp that races ahead into a second call could overwrite smem[0] while
    // slower warps are still reading the previous call's result from it.
    __syncthreads();

    // Stage 1: reduce within each warp; lane 0 holds the warp's sum afterwards.
    val = in_partial_warp ? warpReduceSumPartial(val) : warpReduceSumFull(val);
    if (lane_id == 0)
    {
        smem[warp_id] = val;
    }
    __syncthreads();

    // Stage 2: warp 0 folds the per-warp sums into the block total.
    if (warp_id == 0)
    {
        // Lanes beyond the number of warps have no partial sum to contribute.
        val = lane_id < warp_num ? smem[lane_id] : 0.f;
        // Inside warp 0, in_partial_warp is true exactly when the block consists
        // of a single partial warp, which is the only case where warp 0 is not full.
        val = in_partial_warp ? warpReduceSumPartial(val) : warpReduceSumFull(val);
        if (lane_id == 0)
        {
            smem[0] = val;
        }
    }
    __syncthreads();

    // Broadcast the block total to every thread.
    return smem[0];
}
0 commit comments