@@ -619,84 +619,3 @@ void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx,
     fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data,
             src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream);
 }
-
-template <int block_size>
-static __global__ void fused_rms_rms_norm_f32(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
-        const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
-
-    auto x_row = (const float *)(row < nrows1 ? x1 + row*nb1 : x2 + (row - nrows1)*nb2);
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x_row[col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = rsqrtf(mean + eps);
-
-    auto dst = row < nrows1 ? y1 + row*ncols : y2 + (row - nrows1)*ncols;
-    auto c   = row < nrows1 ? c1 : c2;
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * c[col] * x_row[col];
-    }
-}
-
-static void fused_rms_rms_norm_f32_cuda(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps,
-        const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    int nrows = nrows1 + nrows2;
-    if (ncols < 1024) {
-        const dim3 block_dims(256, 1, 1);
-        fused_rms_rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
-    } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2);
-    }
-}
-
-void ggml_cuda_op_fused_rms_rms_norm([[maybe_unused]] ggml_backend_cuda_context & ctx, [[maybe_unused]] ggml_tensor * rms1, [[maybe_unused]] ggml_tensor * rms2) {
-    GGML_ASSERT(rms1->ne[2] == 1 && rms1->ne[3] == 1);
-    GGML_ASSERT(rms2->ne[2] == 1 && rms2->ne[3] == 1);
-    GGML_ASSERT(rms1->ne[0] == rms2->ne[0]);
-    GGML_ASSERT(rms1->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms2->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms1->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms2->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms1->src[0]->ne[0] == rms1->src[1]->ne[0]);
-    GGML_ASSERT(rms2->src[0]->ne[0] == rms2->src[1]->ne[0]);
-    GGML_ASSERT(ggml_nrows(rms1->src[1]) == 1);
-    GGML_ASSERT(ggml_nrows(rms2->src[1]) == 1);
-    GGML_ASSERT(rms1->src[1]->type == GGML_TYPE_F32);
-    GGML_ASSERT(rms2->src[1]->type == GGML_TYPE_F32);
-
-    float eps1, eps2;
-    memcpy(&eps1, rms1->op_params, sizeof(float));
-    memcpy(&eps2, rms2->op_params, sizeof(float));
-    GGML_ASSERT(eps1 == eps2);
-
-    fused_rms_rms_norm_f32_cuda(rms1->ne[0], rms1->ne[1], rms2->ne[1], rms1->nb[1], rms2->nb[1], eps1,
-            (const char *)rms1->src[0]->data, (const char *)rms2->src[0]->data,
-            (const float *)rms1->src[1]->data, (const float *)rms2->src[1]->data,
-            (float *)rms1->data, (float *)rms2->data, ctx.stream());
-
-
-}
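For context on the deleted code: fused_rms_rms_norm_f32 normalizes the rows of two separate RMS_NORM inputs with a single grid launch, and each block decides which tensor its row belongs to by comparing its row index against nrows1. Below is a minimal standalone sketch of that row-dispatch pattern; it is not part of this patch, it assumes contiguous f32 rows and one warp per row, and the demo_* names are illustrative only. The removed kernel additionally handles arbitrary row strides (nb1/nb2) and larger block sizes via a shared-memory reduction step.

#include <cuda_runtime.h>
#include <math.h>

#define DEMO_WARP_SIZE 32

// butterfly reduction across one warp
static __device__ float demo_warp_reduce_sum(float x) {
    for (int offset = DEMO_WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}

// One block per row of either tensor: blocks [0, nrows1) normalize rows of (x1, c1) into y1,
// the remaining blocks normalize rows of (x2, c2) into y2.
static __global__ void demo_fused_rms_pair(int ncols, int nrows1, float eps,
        const float * x1, const float * x2, const float * c1, const float * c2, float * y1, float * y2) {
    const int row = blockIdx.x;

    const float * x = row < nrows1 ? x1 + (size_t)row*ncols : x2 + (size_t)(row - nrows1)*ncols;
    const float * c = row < nrows1 ? c1 : c2;
    float       * y = row < nrows1 ? y1 + (size_t)row*ncols : y2 + (size_t)(row - nrows1)*ncols;

    // sum of squares for this row, strided over the warp
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += DEMO_WARP_SIZE) {
        sum += x[col]*x[col];
    }
    sum = demo_warp_reduce_sum(sum); // one warp per row, so a shuffle reduction is enough

    const float scale = rsqrtf(sum/ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += DEMO_WARP_SIZE) {
        y[col] = scale * c[col] * x[col];
    }
}

// Example launch covering both tensors with a single grid:
//   demo_fused_rms_pair<<<nrows1 + nrows2, DEMO_WARP_SIZE, 0, stream>>>(ncols, nrows1, eps, x1, x2, c1, c2, y1, y2);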