Commit 0e5e64c

Factor out reduce_rows_f32 from common.cuh
This increases iteration cycle speed: editing the row-reduction kernel no longer forces recompiling every kernel that includes common.cuh.
1 parent ad4a700 commit 0e5e64c

File tree

4 files changed: +23 -20 lines changed


ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 20 deletions
@@ -411,26 +411,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
-// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
-template<bool norm>
-static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col != 0) {
-        return;
-    }
-
-    dst[row] = norm ? sum / ncols : sum;
-}
-
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
 #ifdef GGML_USE_HIP

ggml/src/ggml-cuda/mean.cu

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 #include "mean.cuh"
+#include "reduce_rows.cuh"
 
 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];

ggml/src/ggml-cuda/reduce_rows.cuh

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#include "common.cuh"
+
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
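For context, the kernel assigns one thread block per row (blockIdx.x selects the row); each thread strides over that row's columns accumulating a partial sum, warp_reduce_sum combines the partials across the warp, and thread 0 writes either the raw sum (norm=false) or the mean (norm=true). A minimal host-side driver could look like the sketch below; it is illustrative only and not part of this commit - the buffer sizes and the printed check are made up, and it assumes reduce_rows.cuh and common.cuh from this tree are on the include path (WARP_SIZE and warp_reduce_sum come from common.cuh).

// Illustrative driver for reduce_rows_f32 (not part of this commit).
#include "reduce_rows.cuh"

#include <cstdio>
#include <vector>

int main() {
    const int nrows = 4;
    const int ncols = 1000;

    std::vector<float> h_x(nrows * ncols, 1.0f);   // all ones -> each row sums to ncols
    float * d_x   = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_x,   nrows * ncols * sizeof(float));
    cudaMalloc(&d_dst, nrows * sizeof(float));
    cudaMemcpy(d_x, h_x.data(), nrows * ncols * sizeof(float), cudaMemcpyHostToDevice);

    // One block per row, one warp per block - the same shape sumrows.cu uses.
    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    reduce_rows_f32<false><<<block_nums, block_dims>>>(d_x, d_dst, ncols); // row sums
    // reduce_rows_f32<true> would divide by ncols and produce row means instead.

    std::vector<float> h_dst(nrows);
    cudaMemcpy(h_dst.data(), d_dst, nrows * sizeof(float), cudaMemcpyDeviceToHost);
    printf("row 0 sum = %.1f (expected %d)\n", h_dst[0], ncols);

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}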

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 #include "sumrows.cuh"
+#include "reduce_rows.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
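The hunk ends at the block_dims line, but given the kernel's bool template parameter, sum_rows_f32_cuda presumably continues by launching the norm=false instantiation with one block per row, roughly as in this sketch (not the verbatim file contents):

void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows, 1, 1);   // one block (one warp) per row
    reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}

ggml_cuda_op_mean in mean.cu would presumably use the same launch shape with reduce_rows_f32<true> to produce row means instead of sums.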
