File tree (expand / collapse) — 1 file changed: +16 −3 lines changed
Original file line number | Diff line number | Diff line change
22
33// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
44template <bool norm>
5- static __global__ void reduce_rows_f32 (const float * x, float * dst, const int ncols) {
5+ static __global__ void reduce_rows_f32 (const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
66 const int row = blockIdx .x ;
77 const int col = threadIdx .x ;
88
99 float sum = 0 .0f ;
10- for (int i = col; i < ncols; i += blockDim .x ) {
11- sum += x[row * ncols + i];
10+ const int num_unroll = 24 ;
11+ float temp[num_unroll];
12+ for (int i = col; i < ncols;) {
13+ for (int j = 0 ; j < num_unroll; ++j){
14+ if (i < ncols){
15+ temp[j] = x[row * ncols + i];
16+ }
17+ else {
18+ temp[j] = 0 ;
19+ }
20+ i += blockDim .x ;
21+ }
22+ for (int j = 0 ; j < num_unroll; ++j){
23+ sum += temp[j];
24+ }
1225 }
1326
1427 sum = warp_reduce_sum (sum);
You can’t perform that action at this time.
0 commit comments