Skip to content

Commit c270ffe

Browse files
committed
Hide memory-latency by loop unrolling in reduce_rows_f32
1 parent 3deb3b1 commit c270ffe

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

ggml/src/ggml-cuda/reduce_rows.cuh

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,26 @@
22

33
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
44
template<bool norm>
5-
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
5+
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
66
const int row = blockIdx.x;
77
const int col = threadIdx.x;
88

99
float sum = 0.0f;
10-
for (int i = col; i < ncols; i += blockDim.x) {
11-
sum += x[row * ncols + i];
10+
const int num_unroll = 24;
11+
float temp[num_unroll];
12+
for (int i = col; i < ncols;) {
13+
for (int j = 0; j < num_unroll; ++j){
14+
if (i < ncols){
15+
temp[j] = x[row * ncols + i];
16+
}
17+
else {
18+
temp[j] = 0;
19+
}
20+
i += blockDim.x;
21+
}
22+
for (int j = 0; j < num_unroll; ++j){
23+
sum += temp[j];
24+
}
1225
}
1326

1427
sum = warp_reduce_sum(sum);

0 commit comments

Comments
 (0)