Skip to content

Commit 7c7413e

Browse files
committed
Add CUB-based implementation for GGML_OP_MEAN
Currently this code path is only taken when nrows == 1.
1 parent 4a1c5bc commit 7c7413e

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@
8787
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
8888
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
8989

90+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
91+
# define USE_CUB
92+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
93+
9094
#ifdef __CUDA_ARCH_LIST__
9195
constexpr bool ggml_cuda_has_arch_impl(int) {
9296
return false;

ggml/src/ggml-cuda/mean.cu

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
#include "mean.cuh"
22
#include "reduce_rows.cuh"
33

4+
#ifdef USE_CUB
5+
# include <cub/cub.cuh>
6+
using namespace cub;
7+
#endif // USE_CUB
8+
9+
template <typename T> __global__ void divide_by_count(T * result, size_t count) {
10+
*result /= static_cast<T>(count);
11+
}
12+
413
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
514
const ggml_tensor * src0 = dst->src[0];
615
const float * src0_d = (const float *) src0->data;
@@ -14,6 +23,24 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1423
const int64_t ncols = src0->ne[0];
1524
const int64_t nrows = ggml_nrows(src0);
1625

26+
// Special case for reducing vectors
27+
#ifdef USE_CUB
28+
if (nrows == 1) {
29+
// Single row - use device-wide reduction
30+
size_t tmp_size = 0;
31+
ggml_cuda_pool & pool = ctx.pool();
32+
33+
DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
34+
35+
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
36+
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
37+
38+
// Divide by ncols
39+
divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
40+
return;
41+
}
42+
#endif
43+
1744
const dim3 block_nums(nrows, 1, 1);
1845

1946
const int id = ggml_cuda_get_device();

ggml/src/ggml-cuda/sum.cu

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
2-
#define USE_CUB
3-
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
1+
#include "sum.cuh"
2+
#include "sumrows.cuh"
43

54
#ifdef USE_CUB
6-
#include <cub/cub.cuh>
5+
# include <cub/cub.cuh>
76
using namespace cub;
8-
#endif // USE_CUB
9-
10-
#include "sumrows.cuh"
11-
#include "sum.cuh"
7+
#endif // USE_CUB
128

139
#include <cstdint>
1410

0 commit comments

Comments
 (0)