20 changes: 20 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}

+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
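For context on what this shared template computes: a minimal standalone harness along the following lines (hypothetical, not part of this PR) can check both instantiations against a CPU reference. It supplies its own shuffle-based `warp_reduce_sum` as a stand-in for the one in common.cuh, which is assumed equivalent for float:

```cuda
// Sketch: verify reduce_rows_f32<false> (row sums) and reduce_rows_f32<true>
// (row means) against a CPU reference. Error handling omitted for brevity.
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Stand-in for ggml's warp_reduce_sum (common.cuh): butterfly reduction
// across the 32 lanes of a warp.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x;
}

// Same kernel as the PR: one block per row, one warp per block.
template <bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    for (int i = col; i < ncols; i += blockDim.x) {
        sum += x[row * ncols + i];
    }

    sum = warp_reduce_sum(sum);

    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}

int main() {
    const int nrows = 4;
    const int ncols = 257; // deliberately not a multiple of WARP_SIZE

    std::vector<float> h_x(nrows * ncols);
    for (float & v : h_x) {
        v = (float) rand() / (float) RAND_MAX;
    }

    float * d_x; float * d_sum; float * d_mean;
    cudaMalloc(&d_x,    h_x.size() * sizeof(float));
    cudaMalloc(&d_sum,  nrows      * sizeof(float));
    cudaMalloc(&d_mean, nrows      * sizeof(float));
    cudaMemcpy(d_x, h_x.data(), h_x.size() * sizeof(float), cudaMemcpyHostToDevice);

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    reduce_rows_f32<false><<<block_nums, block_dims>>>(d_x, d_sum,  ncols);
    reduce_rows_f32<true> <<<block_nums, block_dims>>>(d_x, d_mean, ncols);

    std::vector<float> h_sum(nrows);
    std::vector<float> h_mean(nrows);
    cudaMemcpy(h_sum.data(),  d_sum,  nrows * sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_mean.data(), d_mean, nrows * sizeof(float), cudaMemcpyDeviceToHost);

    for (int r = 0; r < nrows; ++r) {
        double ref = 0.0;
        for (int c = 0; c < ncols; ++c) {
            ref += h_x[r * ncols + c];
        }
        printf("row %d: |sum err| = %g, |mean err| = %g\n", r,
               std::fabs(h_sum[r] - ref), std::fabs(h_mean[r] - ref / ncols));
    }

    cudaFree(d_x); cudaFree(d_sum); cudaFree(d_mean);
    return 0;
}
```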
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,6 +37,7 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
@@ -2357,6 +2358,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
        case GGML_OP_SUM_ROWS:
            ggml_cuda_op_sum_rows(ctx, dst);
            break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
        case GGML_OP_SSM_CONV:
            ggml_cuda_op_ssm_conv(ctx, dst);
            break;
@@ -3260,6 +3264,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
        case GGML_OP_POOL_2D:
        case GGML_OP_SUM:
        case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
            return true;
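With GGML_OP_MEAN wired into both the forward dispatch and supports_op, an existing graph op, ggml_mean, now stays on the CUDA backend instead of being reported as unsupported (which would previously have forced the scheduler to fall back to another backend for this node). A rough user-level sketch of the op follows, graph construction only; backend setup, tensor allocation, and graph execution are elided because those APIs vary between ggml revisions:

```cpp
#include "ggml.h"

// Build a graph that reduces each 256-element row of a 256x256 F32 matrix
// to its mean. ggml_mean reduces along ne[0], so the result has shape {1, 256}.
static ggml_cgraph * build_mean_graph(ggml_context * ctx) {
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 256);
    ggml_tensor * m = ggml_mean(ctx, a); // dispatches to ggml_cuda_op_mean on CUDA

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, m);
    return gf;
}
```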
19 changes: 19 additions & 0 deletions ggml/src/ggml-cuda/mean.cu
@@ -0,0 +1,19 @@
#include "mean.cuh"

void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));

const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
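A worked example of the launch geometry, using the shape from the perf test added below: for src0 of shape {256, 256, 3, 1}, ncols = 256 and ggml_nrows(src0) = 256 · 3 · 1 = 768, so the kernel runs as 768 blocks of a single 32-thread warp (assuming WARP_SIZE = 32); each thread accumulates 256 / 32 = 8 elements before the warp reduction, and lane 0 writes sum / ncols. Note that ncols is int64_t at the call site but int in the kernel; the implicit narrowing is harmless for any realistic row length.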
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/mean.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
23 changes: 5 additions & 18 deletions ggml/src/ggml-cuda/sumrows.cu
@@ -1,25 +1,9 @@
#include "sumrows.cuh"

-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}

void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
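Two details of this refactor are worth spelling out: the standalone k_sum_rows_f32 kernel is deleted in favor of the shared reduce_rows_f32</*norm*/false> instantiation, and ggml_cuda_op_sum_rows now launches the template directly, mirroring mean.cu. The sum_rows_f32_cuda wrapper remains defined above, but with its declaration removed from sumrows.cuh (next file) it is no longer reachable from other translation units.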
1 change: 0 additions & 1 deletion ggml/src/ggml-cuda/sumrows.cuh
@@ -1,5 +1,4 @@
#include "common.cuh"

-void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);

void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
2 changes: 2 additions & 0 deletions tests/test-backend-ops.cpp
@@ -4652,6 +4652,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));

+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
    return test_cases;
}
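Assuming the usual flags of this harness, the new case should be runnable in isolation with something like `test-backend-ops perf -o MEAN` (mode and filter spelling assumed, not verified against this revision); the test_mean class itself already exists, so correctness coverage for GGML_OP_MEAN presumably comes from the eval test list.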
