
Commit ece608a

Further optimizations to reduce_rows_f32
1. Increase the threadblock size to better hide the latency of memory requests. As a consequence of the bigger threadblocks, do a two-step summation, using shared memory to communicate partial results between the two warp_reduce_sum invocations.
2. Use a sum_temp array to reduce waits on sum (see the annotated excerpt below).
3. Adjust num_unroll to reflect the bigger threadblock.
4. Improve the default block_dims and extend support to more block_dims values.
1 parent c270ffe commit ece608a
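
For context, the core of point 2 as it lands in reduce_rows.cuh below, with the reasoning spelled out in comments (this is an annotated excerpt of the diff, not additional code from the commit):

    // Before: every iteration extends a single dependency chain on `sum`,
    // so each add has to wait for the previous one to complete.
    for (int j = 0; j < num_unroll; ++j) {
        sum += temp[j];
    }

    // After: num_unroll independent accumulators inside the main loop,
    // so the adds for different j can be in flight at the same time ...
    for (int j = 0; j < num_unroll; ++j) {
        sum_temp[j] += temp[j];
    }

    // ... plus a single combine pass once the main loop over the row is done.
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }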

3 files changed: +25 −5 lines


ggml/src/ggml-cuda/mean.cu

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    constexpr dim3 block_dims(512, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
     reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }
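
Taken together, the launch shape used by both mean and sum_rows after this change is: one threadblock per row, 512 threads per block striding across the columns of that row. Below is an annotated copy of the launch lines from the hunk above; the comment about blockIdx.x is an inference from block_nums, not something shown in this diff:

    constexpr dim3 block_dims(512, 1, 1);   // 512 threads cooperate on a single row
    const dim3 block_nums(nrows, 1, 1);     // one block per row; the kernel presumably selects its row via blockIdx.x
    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);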

ggml/src/ggml-cuda/reduce_rows.cuh

Lines changed: 22 additions & 2 deletions

@@ -7,8 +7,9 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
     const int col = threadIdx.x;
 
     float sum = 0.0f;
-    const int num_unroll = 24;
+    const int num_unroll = 8;
     float temp[num_unroll];
+    float sum_temp[num_unroll] = {0.0f};
     for (int i = col; i < ncols;) {
         for (int j = 0; j < num_unroll; ++j){
             if (i < ncols){

@@ -20,11 +21,30 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
             i += blockDim.x;
         }
         for (int j = 0; j < num_unroll; ++j){
-            sum += temp[j];
+            sum_temp[j] += temp[j];
         }
     }
+    for (int j = 0; j < num_unroll; ++j){
+        sum += sum_temp[j];
+    }
 
+    // sum up partial sums
     sum = warp_reduce_sum(sum);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = 0.0f;
+        if (lane_id < (blockDim.x / WARP_SIZE)) {
+            sum = s_sum[lane_id];
+        }
+        sum = warp_reduce_sum(sum);
+    }
 
     if (col != 0) {
         return;
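
The two-step reduction above first reduces within each warp, stages one partial sum per warp in s_sum, and then has the first warp reduce those staged values. It relies on warp_reduce_sum from ggml's common CUDA header, which is not part of this diff; a minimal sketch of what such a shuffle-based helper typically looks like (an assumption for illustration, not code from this commit):

    // Butterfly reduction within one warp: after the loop, every lane holds
    // the sum of the values originally held by all WARP_SIZE (32) lanes.
    static __device__ __forceinline__ float warp_reduce_sum(float x) {
    #pragma unroll
        for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
        }
        return x;
    }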

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 #include "reduce_rows.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    constexpr dim3 block_dims(512, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
     reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }

@@ -20,7 +20,7 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    constexpr dim3 block_dims(512, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
 
     reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
