@@ -131,40 +131,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 }
 
-template <int block_size>
-static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = rsqrtf(mean + eps);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
-    }
-}
-
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
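The removed kernel relies on the warp_reduce_sum() helper defined elsewhere in ggml's shared CUDA headers, which this diff does not show. For reference, a typical butterfly-shuffle implementation looks roughly like the sketch below; this is an assumption for illustration, not code from this commit.

// Illustrative sketch only: sum a value across the 32 lanes of a warp with
// XOR shuffles. WARP_SIZE is 32, matching the constant used in ggml's CUDA code;
// the helper actually used by this file may differ in detail.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}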
@@ -197,18 +163,6 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     }
 }
 
-static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
-        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    }
-}
-
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
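For readers comparing against the removed device code: per row, the fused op computes dst = rsqrt(mean(x^2) + eps) * y * x, with the single weight row y broadcast across all rows of x (one CUDA block per row in the launcher above). A minimal host-side reference sketch, assuming contiguous row-major buffers; illustrative only, not part of this commit.

#include <math.h>

// Reference of the fused RMS norm + scale with the layout the removed kernel
// assumed: x is nrows x ncols, y is a single row of ncols weights.
static void fused_rms_norm_ref(const float * x, const float * y, float * dst,
                               int ncols, int nrows, float eps) {
    for (int r = 0; r < nrows; ++r) {
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            const float xi = x[r*ncols + c];
            sum += xi*xi;                                    // sum of squares for this row
        }
        const float scale = 1.0f/sqrtf(sum/ncols + eps);     // rsqrt of the mean square
        for (int c = 0; c < ncols; ++c) {
            dst[r*ncols + c] = scale * y[c] * x[r*ncols + c];
        }
    }
}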
@@ -268,32 +222,3 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
 }
-
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    if (!dst->src[1]) {
-        ggml_cuda_op_rms_norm(ctx, dst);
-        return;
-    }
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
-}
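With the fused entry point removed, the same result can still be expressed with stock ggml graph ops. A hedged sketch, assuming the usual ggml_rms_norm()/ggml_mul() API and a caller-provided context; the helper name is illustrative, not from this commit.

#include "ggml.h"

// Unfused equivalent at the graph level: normalize each row of x, then scale
// element-wise by the broadcast weight row y.
static struct ggml_tensor * rms_norm_then_scale(struct ggml_context * ctx,
        struct ggml_tensor * x, struct ggml_tensor * y, float eps) {
    struct ggml_tensor * cur = ggml_rms_norm(ctx, x, eps); // per-row RMS normalization
    return ggml_mul(ctx, cur, y);                          // multiply by weights
}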