 #include "mean.cuh"
 #include "reduce_rows.cuh"
 
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+    *result /= static_cast<T>(count);
+}
+
 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;
@@ -14,6 +23,24 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
+    // Special case for reducing vectors
+#ifdef USE_CUB
+    if (nrows == 1) {
+        // Single row - use device-wide reduction
+        size_t tmp_size = 0;
+        ggml_cuda_pool & pool = ctx.pool();
+
+        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        // Divide by ncols
+        divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
+        return;
+    }
+#endif
+
     const dim3 block_nums(nrows, 1, 1);
 
     const int id = ggml_cuda_get_device();
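For reference, the two DeviceReduce::Sum calls in the nrows == 1 branch follow CUB's standard two-phase pattern: the first call, with a null temp-storage pointer, only reports how many scratch bytes the reduction needs, and the second call performs the actual device-wide sum. Below is a minimal standalone sketch of that pattern; the buffer names and sizes are illustrative and not taken from the ggml code.

// Standalone sketch of the CUB two-phase device-wide reduction pattern.
// Assumption: plain cudaMalloc for scratch space; the patch above instead
// draws the scratch buffer from ctx.pool() via ggml_cuda_pool_alloc.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int n = 1 << 20;
    std::vector<float> h_in(n, 1.0f);      // 2^20 ones -> expected sum 1048576

    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

    // Phase 1: query the required temporary storage; no reduction runs yet.
    void  *d_temp     = nullptr;
    size_t temp_bytes = 0;
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

    // Phase 2: allocate the scratch buffer and perform the actual sum.
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", h_out);            // prints 1048576.000000

    cudaFree(d_temp);
    cudaFree(d_out);
    cudaFree(d_in);
    return 0;
}

In the patch itself, the reduction writes the sum into dst_d and the single-thread divide_by_count kernel then divides it by ncols on the stream, turning the sum into the mean without a host-side synchronization.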