Skip to content

Commit 3deb3b1

Browse files
committed
Factor out reduce_rows_f32 from common.cuh
This shortens the development iteration cycle: changes to this kernel no longer force every kernel that includes common.cuh to be recompiled.
1 parent 36d3f00 commit 3deb3b1

File tree

4 files changed

+23
-20
lines changed

4 files changed

+23
-20
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -412,26 +412,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
412412
#endif // FP16_AVAILABLE
413413
}
414414

415-
// Row reduction kernel template: writes one value per row of x into dst.
//   norm == false: dst[row] = sum of the row's ncols elements
//   norm == true : dst[row] = that sum divided by ncols (the mean)
// The row index is blockIdx.x; row r is read starting at x[r * ncols].
// NOTE(review): partial sums are merged only through warp_reduce_sum, so this
// looks like it expects a single warp per block - confirm against the launch sites.
// NOTE(review): row * ncols + i is evaluated in int - verify it cannot overflow
// for the largest inputs.
416-
template<bool norm>
417-
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
418-
const int row = blockIdx.x;
419-
const int col = threadIdx.x;
420-
421-
float sum = 0.0f;
422-
// Each thread starts at its own column and advances by blockDim.x.
for (int i = col; i < ncols; i += blockDim.x) {
423-
sum += x[row * ncols + i];
424-
}
425-
426-
// Combine the per-thread partial sums.
sum = warp_reduce_sum(sum);
427-
428-
// Only the thread with threadIdx.x == 0 stores the result for this row.
if (col != 0) {
429-
return;
430-
}
431-
432-
dst[row] = norm ? sum / ncols : sum;
433-
}
434-
435415
template<int width = WARP_SIZE>
436416
static __device__ __forceinline__ int warp_reduce_all(int x) {
437417
#ifdef GGML_USE_HIP

ggml/src/ggml-cuda/mean.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mean.cuh"
2+
#include "reduce_rows.cuh"
23

34
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
45
const ggml_tensor * src0 = dst->src[0];

ggml/src/ggml-cuda/reduce_rows.cuh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include "common.cuh"
2+
3+
// Row reduction kernel template: writes one value per row of x into dst.
//   norm == false: dst[row] = sum of the row's ncols elements
//   norm == true : dst[row] = that sum divided by ncols (the mean)
// The row index is blockIdx.x; row r is read starting at x[r * ncols].
// NOTE(review): partial sums are merged only through warp_reduce_sum, so this
// looks like it expects a single warp per block - confirm against the launch sites.
// NOTE(review): row * ncols + i is evaluated in int - verify it cannot overflow
// for the largest inputs.
4+
template<bool norm>
5+
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
6+
const int row = blockIdx.x;
7+
const int col = threadIdx.x;
8+
9+
float sum = 0.0f;
10+
// Each thread starts at its own column and advances by blockDim.x.
for (int i = col; i < ncols; i += blockDim.x) {
11+
sum += x[row * ncols + i];
12+
}
13+
14+
// Combine the per-thread partial sums.
sum = warp_reduce_sum(sum);
15+
16+
// Only the thread with threadIdx.x == 0 stores the result for this row.
if (col != 0) {
17+
return;
18+
}
19+
20+
dst[row] = norm ? sum / ncols : sum;
21+
}

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "sumrows.cuh"
2+
#include "reduce_rows.cuh"
23

34
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
45
const dim3 block_dims(WARP_SIZE, 1, 1);

0 commit comments

Comments
 (0)