@@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 }
 
+template <int block_size>
+static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row*ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
+    }
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     }
 }
 
+static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
+        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
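The wrapper mirrors rms_norm_f32_cuda: one warp per row when ncols < 1024, a full 1024-thread block otherwise, and only the latter path needs the shared-memory reduction in the kernel. A standalone launch sketch, assuming the kernel and wrapper from this patch are in the same translation unit (the harness name and constant inputs are made up for illustration; error checking omitted):

#include <cuda_runtime.h>
#include <vector>

static void test_fused_rms_norm(int ncols, int nrows, float eps) {
    // ncols must be a multiple of WARP_SIZE, as asserted by the wrapper.
    std::vector<float> hx(size_t(ncols)*nrows, 1.0f), hy(ncols, 2.0f), hdst(size_t(ncols)*nrows);

    float *dx, *dy, *ddst;
    cudaMalloc(&dx,   hx.size()*sizeof(float));
    cudaMalloc(&dy,   hy.size()*sizeof(float));
    cudaMalloc(&ddst, hdst.size()*sizeof(float));
    cudaMemcpy(dx, hx.data(), hx.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy.data(), hy.size()*sizeof(float), cudaMemcpyHostToDevice);

    // Launch through the new wrapper on the default stream.
    fused_rms_norm_f32_cuda(dx, dy, ddst, ncols, nrows, eps, 0);

    cudaMemcpy(hdst.data(), ddst, hdst.size()*sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dx); cudaFree(dy); cudaFree(ddst);
}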
@@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
 }
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    if (!dst->src[1]) {
+        ggml_cuda_op_rms_norm(ctx, dst);
+        return;
+    }
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+}
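Not shown in this diff: the new entry point still has to be reached from the backend's op dispatch. Assuming the usual switch over dst->op in the CUDA backend's compute function and an op value named GGML_OP_FUSED_RMS_NORM (both taken to exist in the surrounding codebase, not established by this hunk), the wiring would look roughly like the fragment below:

        // Hypothetical dispatch case (not part of this diff): route the fused
        // op to the new entry point inside the backend's compute switch.
        case GGML_OP_FUSED_RMS_NORM:
            ggml_cuda_op_fused_rms_norm(ctx, dst);
            break;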