@@ -455,3 +455,86 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
         fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
     }
 }
+
+template <int block_size>
+static __global__ void fused_add_rms_norm_f32(const float * a, const float * b, const float * c,
+        float * dst_add, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = a[row*ncols + col] + b[row*ncols + col];
+        tmp += xi * xi;
+        dst_add[row*ncols + col] = xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = scale * c[col] * dst_add[row*ncols + col];
+    }
+}
+
+
+static void fused_add_rms_norm_f32_cuda(const float * a, const float * b, const float * c, float * dst_add, float * dst,
+        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(256, 1, 1);
+        fused_add_rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(a, b, c, dst_add, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_add_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(a, b, c, dst_add, dst, ncols, eps);
+    }
+}
+
+void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    //const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(add->data == src0->data);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(add->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(add->src[1]));
+    GGML_ASSERT(ggml_are_same_shape(add->src[0], add->src[1]));
+    GGML_ASSERT(ggml_are_same_shape(add->src[0], src0));
+    GGML_ASSERT(add->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(add->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t nrows = ggml_nrows(src0);
+    fused_add_rms_norm_f32_cuda((const float *)add->src[0]->data, (const float *)add->src[1]->data,
+            src1_d, (float *)add->data, dst_d, ne00, nrows, eps, stream);
+}
+
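For reference, a minimal host-side sketch of what the fused kernel above computes for a single row. This is not part of the commit: the helper name fused_add_rms_norm_ref and the test values are illustrative only, and it simply restates the kernel's math (dst_add = a + b, then dst = c * dst_add * rsqrt(mean(dst_add^2) + eps)) in plain C++ so the result can be checked on the CPU.

// Illustrative CPU reference (not from the commit): one row of fused add + RMS norm.
#include <cmath>
#include <cstdio>
#include <vector>

static void fused_add_rms_norm_ref(const float * a, const float * b, const float * c,
        float * dst_add, float * dst, const int ncols, const float eps) {
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        const float xi = a[i] + b[i]; // the Add result is kept, like dst_add in the kernel
        dst_add[i] = xi;
        sum += xi*xi;
    }
    const float scale = 1.0f/std::sqrt(sum/ncols + eps); // rsqrtf(mean + eps) on the GPU
    for (int i = 0; i < ncols; ++i) {
        dst[i] = scale * c[i] * dst_add[i]; // c is the single-row RMS-norm weight
    }
}

int main() {
    const int ncols = 8;
    std::vector<float> a(ncols, 1.0f), b(ncols, 0.5f), c(ncols, 2.0f), add(ncols), out(ncols);
    fused_add_rms_norm_ref(a.data(), b.data(), c.data(), add.data(), out.data(), ncols, 1e-6f);
    for (int i = 0; i < ncols; ++i) {
        printf("%g ", out[i]); // each element is ~2.0: xi = 1.5, rms = 1.5, weight = 2.0
    }
    printf("\n");
    return 0;
}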