@@ -215,40 +215,6 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-template <int block_size>
-static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = rsqrtf(mean + eps);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
-    }
-}
-
 // template <int block_size>
 // static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
 //     const int row = blockIdx.x*blockDim.y + threadIdx.y;
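For reference, the kernel removed above normalizes each row of `x` by its RMS and scales it element-wise by the single-row tensor `y`: `dst[row][col] = y[col] * x[row][col] / sqrt(mean(x[row]^2) + eps)`. A minimal CPU sketch of that computation (a hypothetical helper, not part of this tree, only to document what the fused kernel produced):

```cpp
#include <cmath>
#include <cstddef>

// CPU reference for the removed fused_rms_norm_f32 kernel: each row of x is
// scaled by 1/sqrt(mean(x^2) + eps) and multiplied element-wise by y.
static void fused_rms_norm_ref(const float * x, const float * y, float * dst,
                               int ncols, int nrows, float eps) {
    for (int row = 0; row < nrows; ++row) {
        double sum = 0.0;
        for (int col = 0; col < ncols; ++col) {
            const float xi = x[(size_t)row*ncols + col];
            sum += (double)xi * xi;
        }
        const float scale = 1.0f / sqrtf((float)(sum/ncols) + eps);
        for (int col = 0; col < ncols; ++col) {
            dst[(size_t)row*ncols + col] = scale * y[col] * x[(size_t)row*ncols + col];
        }
    }
}
```

Fusing the multiply into the normalization kernel saves a second pass over the row compared with running RMS norm and an element-wise multiply as separate kernels.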
@@ -395,19 +361,6 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
     }
 }
 
-
-static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
-        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    }
-}
-
 static void l2_norm_f32_cuda(
         const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
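The launcher removed in this hunk dispatched on the row width: rows shorter than 1024 columns used one warp per row, anything wider used a 1024-thread block and the shared-memory reduction path in the kernel. A hypothetical standalone harness for exercising it, assuming the removed kernel and launcher plus the `fused_rms_norm_ref` sketch above are compiled into the same `.cu` translation unit (none of this exists in the tree):

```cpp
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const int ncols = 4096, nrows = 8;  // ncols must be a multiple of WARP_SIZE (32)
    const float eps = 1e-6f;

    std::vector<float> x((size_t)ncols*nrows), y(ncols), ref(x.size()), out(x.size());
    for (auto & v : x) v = (float)rand()/RAND_MAX - 0.5f;
    for (auto & v : y) v = (float)rand()/RAND_MAX;

    float *x_d, *y_d, *dst_d;
    cudaMalloc(&x_d,   x.size()*sizeof(float));
    cudaMalloc(&y_d,   y.size()*sizeof(float));
    cudaMalloc(&dst_d, out.size()*sizeof(float));
    cudaMemcpy(x_d, x.data(), x.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(y_d, y.data(), y.size()*sizeof(float), cudaMemcpyHostToDevice);

    // ncols >= 1024, so this exercises the 1024-thread block path of the removed launcher.
    fused_rms_norm_f32_cuda(x_d, y_d, dst_d, ncols, nrows, eps, /*stream=*/0);
    cudaMemcpy(out.data(), dst_d, out.size()*sizeof(float), cudaMemcpyDeviceToHost);

    // Compare against the CPU reference sketched earlier.
    fused_rms_norm_ref(x.data(), y.data(), ref.data(), ncols, nrows, eps);
    float max_err = 0.0f;
    for (size_t i = 0; i < out.size(); ++i) max_err = fmaxf(max_err, fabsf(out[i] - ref[i]));
    printf("max abs err = %g\n", max_err);

    cudaFree(x_d); cudaFree(y_d); cudaFree(dst_d);
    return 0;
}
```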
@@ -567,36 +520,6 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
     rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
 
-
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    if (!dst->src[1]) {
-        ggml_cuda_op_rms_norm(ctx, dst);
-        return;
-    }
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
-}
-
 void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;
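The removed wrapper fell back to the plain `ggml_cuda_op_rms_norm` when `src1` was absent; otherwise it required a contiguous F32 `src0`, an F32 single-row `src1` matching `src0`'s row length, and read `eps` back out of `op_params`. A sketch of those preconditions as a hypothetical boolean helper (name and placement are assumptions; only ggml calls already used by the removed code appear):

```cpp
#include <cstring>
#include "ggml.h"

// Mirrors the checks the removed wrapper asserted: the fused path applies only
// when src1 is a single F32 row whose length matches src0's row size, src0 is
// contiguous, and all tensors are F32. eps is stored as raw bytes in op_params.
static bool fused_rms_norm_applicable(const ggml_tensor * dst, float * eps_out) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    if (!src1) {
        return false; // wrapper fell back to the plain RMS norm path
    }
    if (!ggml_is_contiguous(src0)) return false;
    if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false;
    if (src0->ne[0] != src1->ne[0]) return false;
    if (ggml_nrows(src1) != 1) return false;

    memcpy(eps_out, dst->op_params, sizeof(float)); // same byte-wise read the wrapper used
    return true;
}
```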