#include "sumrows.cuh"
22
// Sums each row of the row-major f32 matrix x (ncols columns) into dst,
// producing one float per row.
// Launch contract (see sum_rows_f32_cuda): one block per row, with
// blockIdx.x selecting the row and blockDim.x == WARP_SIZE — the final
// reduction is warp-wide only, so wider blocks would drop partial sums.
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
    const int row  = blockIdx.x;
    const int lane = threadIdx.x;

    // Each lane accumulates a strided partial sum over its row.
    float partial = 0.0f;
    for (int i = lane; i < ncols; i += blockDim.x) {
        partial += x[row * ncols + i];
    }

    // Combine per-lane partials across the warp; lane 0 ends up with the total.
    partial = warp_reduce_sum(partial);

    if (lane == 0) {
        dst[row] = partial;
    }
}
18-
// Computes per-row sums of an nrows x ncols row-major f32 matrix on the GPU,
// writing one float per row into dst. Enqueues the work on the given stream
// (asynchronous with respect to the host).
// Uses one warp-sized block per row, matching k_sum_rows_f32's warp-only
// reduction.
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 grid_dims(nrows, 1, 1);
    const dim3 warp_block(WARP_SIZE, 1, 1);
    k_sum_rows_f32<<<grid_dims, warp_block, 0, stream>>>(x, dst, ncols);
}
24-
253void ggml_cuda_op_sum_rows (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
264 const ggml_tensor * src0 = dst->src [0 ];
275 const float * src0_d = (const float *)src0->data ;
@@ -35,5 +13,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
3513 const int64_t ncols = src0->ne [0 ];
3614 const int64_t nrows = ggml_nrows (src0);
3715
38- sum_rows_f32_cuda (src0_d, dst_d, ncols, nrows, stream);
16+ const dim3 block_dims (WARP_SIZE, 1 , 1 );
17+ const dim3 block_nums (nrows, 1 , 1 );
18+
19+ reduce_rows_f32</* norm=*/ false ><<<block_nums, block_dims, 0 , stream>>> (src0_d, dst_d, ncols);
3920}