#include "sumrows.cuh"
22
// Row-sum kernel: one block per row, dst[row] = sum of x[row*ncols .. row*ncols+ncols-1].
// Expected launch config (see sum_rows_f32_cuda below): blockDim.x == WARP_SIZE,
// gridDim.x == nrows — a single warp_reduce_sum then combines all partials.
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
    const int row  = blockIdx.x;
    const int lane = threadIdx.x;

    // Each thread accumulates a strided slice of the row:
    // lane handles columns lane, lane + blockDim.x, lane + 2*blockDim.x, ...
    float partial = 0.0f;
    for (int col = lane; col < ncols; col += blockDim.x) {
        partial += x[row*ncols + col];
    }

    // Combine the per-lane partials across the warp.
    partial = warp_reduce_sum(partial);

    // Lane 0 now holds the full row total; only it writes the result.
    if (lane == 0) {
        dst[row] = partial;
    }
}
18-
// Asynchronously computes per-row sums of a row-major nrows x ncols matrix:
// dst[r] = sum_{c} x[r*ncols + c], for r in [0, nrows).
//
// x      - device pointer, nrows*ncols contiguous floats (read-only)
// dst    - device pointer, nrows floats (written)
// stream - CUDA stream the launch is enqueued on; no synchronization is done here.
//
// Fix: the block previously contained two consecutive launches of the same
// reduction (a leftover merge/diff conflict) — the second overwrote the first's
// identical result and referenced a symbol not defined in this file. A single
// launch of the locally defined kernel suffices.
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_dims(WARP_SIZE, 1, 1); // one warp per block
    const dim3 block_nums(nrows, 1, 1);     // one block per row
    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
248
259void ggml_cuda_op_sum_rows (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
3519 const int64_t ncols = src0->ne [0 ];
3620 const int64_t nrows = ggml_nrows (src0);
3721
38- sum_rows_f32_cuda (src0_d, dst_d, ncols, nrows, stream);
22+ const dim3 block_dims (WARP_SIZE, 1 , 1 );
23+ const dim3 block_nums (nrows, 1 , 1 );
24+
25+ reduce_rows_f32</* norm=*/ false ><<<block_nums, block_dims, 0 , stream>>> (src0_d, dst_d, ncols);
3926}
0 commit comments