@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
36143614 }
36153615}
36163616
3617+ // ggml_compute_forward_geglu_erf
3618+
3619+ static void ggml_compute_forward_geglu_erf_f32 (
3620+ const ggml_compute_params * params,
3621+ ggml_tensor * dst) {
3622+
3623+ const ggml_tensor * src0 = dst->src [0 ];
3624+ const ggml_tensor * src1 = dst->src [1 ];
3625+ char * src0_d = (char *) src0->data ;
3626+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3627+ const size_t src0_o = src0->nb [1 ];
3628+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3629+
3630+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3631+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3632+
3633+ if (src1) {
3634+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3635+ GGML_ASSERT (src0->type == src1->type );
3636+ }
3637+
3638+ const int ith = params->ith ;
3639+ const int nth = params->nth ;
3640+
3641+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3642+ const int nr = ggml_nrows (src0);
3643+
3644+ GGML_ASSERT (dst->ne [0 ] == nc);
3645+ GGML_ASSERT (ggml_nrows (dst) == nr);
3646+
3647+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3648+
3649+ // rows per thread
3650+ const int dr = (nr + nth - 1 )/nth;
3651+
3652+ // row range for this thread
3653+ const int ir0 = dr*ith;
3654+ const int ir1 = MIN (ir0 + dr, nr);
3655+
3656+ for (int i1 = ir0; i1 < ir1; i1++) {
3657+ float * src0_p = (float *) (src0_d + i1*src0_o);
3658+ float * src1_p = (float *) (src1_d + i1*src1_o);
3659+
3660+ if (!src1) {
3661+ src0_p += swapped ? nc : 0 ;
3662+ src1_p += swapped ? 0 : nc;
3663+ }
3664+
3665+ ggml_vec_geglu_erf_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3666+
3667+ #ifndef NDEBUG
3668+ for (int k = 0 ; k < nc; k++) {
3669+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3670+ GGML_UNUSED (x);
3671+ assert (!isnan (x));
3672+ assert (!isinf (x));
3673+ }
3674+ #endif
3675+ }
3676+ }
3677+
3678+ static void ggml_compute_forward_geglu_erf_f16 (
3679+ const ggml_compute_params * params,
3680+ ggml_tensor * dst) {
3681+
3682+ const ggml_tensor * src0 = dst->src [0 ];
3683+ const ggml_tensor * src1 = dst->src [1 ];
3684+ char * src0_d = (char *) src0->data ;
3685+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3686+ const size_t src0_o = src0->nb [1 ];
3687+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3688+
3689+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3690+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3691+
3692+ if (src1) {
3693+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3694+ GGML_ASSERT (src0->type == src1->type );
3695+ }
3696+
3697+ const int ith = params->ith ;
3698+ const int nth = params->nth ;
3699+
3700+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3701+ const int nr = ggml_nrows (src0);
3702+
3703+ GGML_ASSERT (dst->ne [0 ] == nc);
3704+ GGML_ASSERT (ggml_nrows (dst) == nr);
3705+
3706+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3707+
3708+ // rows per thread
3709+ const int dr = (nr + nth - 1 )/nth;
3710+
3711+ // row range for this thread
3712+ const int ir0 = dr*ith;
3713+ const int ir1 = MIN (ir0 + dr, nr);
3714+
3715+ for (int i1 = ir0; i1 < ir1; i1++) {
3716+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3717+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3718+
3719+ if (!src1) {
3720+ src0_p += swapped ? nc : 0 ;
3721+ src1_p += swapped ? 0 : nc;
3722+ }
3723+
3724+ ggml_vec_geglu_erf_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3725+
3726+ #ifndef NDEBUG
3727+ for (int k = 0 ; k < nc; k++) {
3728+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3729+ const float v = GGML_FP16_TO_FP32 (x);
3730+ GGML_UNUSED (v);
3731+ assert (!isnan (v));
3732+ assert (!isinf (v));
3733+ }
3734+ #endif
3735+ }
3736+ }
3737+
3738+ static void ggml_compute_forward_geglu_erf (
3739+ const ggml_compute_params * params,
3740+ ggml_tensor * dst) {
3741+
3742+ const ggml_tensor * src0 = dst->src [0 ];
3743+
3744+ switch (src0->type ) {
3745+ case GGML_TYPE_F32:
3746+ {
3747+ ggml_compute_forward_geglu_erf_f32 (params, dst);
3748+ } break ;
3749+ case GGML_TYPE_F16:
3750+ {
3751+ ggml_compute_forward_geglu_erf_f16 (params, dst);
3752+ } break ;
3753+ default :
3754+ {
3755+ GGML_ABORT (" fatal error" );
3756+ }
3757+ }
3758+ }
3759+
3760+ // ggml_compute_forward_geglu_quick
3761+
3762+ static void ggml_compute_forward_geglu_quick_f32 (
3763+ const ggml_compute_params * params,
3764+ ggml_tensor * dst) {
3765+
3766+ const ggml_tensor * src0 = dst->src [0 ];
3767+ const ggml_tensor * src1 = dst->src [1 ];
3768+ char * src0_d = (char *) src0->data ;
3769+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3770+ const size_t src0_o = src0->nb [1 ];
3771+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3772+
3773+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3774+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3775+
3776+ if (src1) {
3777+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3778+ GGML_ASSERT (src0->type == src1->type );
3779+ }
3780+
3781+ const int ith = params->ith ;
3782+ const int nth = params->nth ;
3783+
3784+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3785+ const int nr = ggml_nrows (src0);
3786+
3787+ GGML_ASSERT (dst->ne [0 ] == nc);
3788+ GGML_ASSERT (ggml_nrows (dst) == nr);
3789+
3790+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3791+
3792+ // rows per thread
3793+ const int dr = (nr + nth - 1 )/nth;
3794+
3795+ // row range for this thread
3796+ const int ir0 = dr*ith;
3797+ const int ir1 = MIN (ir0 + dr, nr);
3798+
3799+ for (int i1 = ir0; i1 < ir1; i1++) {
3800+ float * src0_p = (float *) (src0_d + i1*src0_o);
3801+ float * src1_p = (float *) (src1_d + i1*src1_o);
3802+
3803+ if (!src1) {
3804+ src0_p += swapped ? nc : 0 ;
3805+ src1_p += swapped ? 0 : nc;
3806+ }
3807+
3808+ ggml_vec_geglu_quick_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3809+
3810+ #ifndef NDEBUG
3811+ for (int k = 0 ; k < nc; k++) {
3812+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3813+ GGML_UNUSED (x);
3814+ assert (!isnan (x));
3815+ assert (!isinf (x));
3816+ }
3817+ #endif
3818+ }
3819+ }
3820+
3821+ static void ggml_compute_forward_geglu_quick_f16 (
3822+ const ggml_compute_params * params,
3823+ ggml_tensor * dst) {
3824+
3825+ const ggml_tensor * src0 = dst->src [0 ];
3826+ const ggml_tensor * src1 = dst->src [1 ];
3827+ char * src0_d = (char *) src0->data ;
3828+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3829+ const size_t src0_o = src0->nb [1 ];
3830+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3831+
3832+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3833+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3834+
3835+ if (src1) {
3836+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3837+ GGML_ASSERT (src0->type == src1->type );
3838+ }
3839+
3840+ const int ith = params->ith ;
3841+ const int nth = params->nth ;
3842+
3843+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3844+ const int nr = ggml_nrows (src0);
3845+
3846+ GGML_ASSERT (dst->ne [0 ] == nc);
3847+ GGML_ASSERT (ggml_nrows (dst) == nr);
3848+
3849+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3850+
3851+ // rows per thread
3852+ const int dr = (nr + nth - 1 )/nth;
3853+
3854+ // row range for this thread
3855+ const int ir0 = dr*ith;
3856+ const int ir1 = MIN (ir0 + dr, nr);
3857+
3858+ for (int i1 = ir0; i1 < ir1; i1++) {
3859+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3860+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3861+
3862+ if (!src1) {
3863+ src0_p += swapped ? nc : 0 ;
3864+ src1_p += swapped ? 0 : nc;
3865+ }
3866+
3867+ ggml_vec_geglu_quick_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3868+
3869+ #ifndef NDEBUG
3870+ for (int k = 0 ; k < nc; k++) {
3871+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3872+ const float v = GGML_FP16_TO_FP32 (x);
3873+ GGML_UNUSED (v);
3874+ assert (!isnan (v));
3875+ assert (!isinf (v));
3876+ }
3877+ #endif
3878+ }
3879+ }
3880+
3881+ static void ggml_compute_forward_geglu_quick (
3882+ const ggml_compute_params * params,
3883+ ggml_tensor * dst) {
3884+
3885+ const ggml_tensor * src0 = dst->src [0 ];
3886+
3887+ switch (src0->type ) {
3888+ case GGML_TYPE_F32:
3889+ {
3890+ ggml_compute_forward_geglu_quick_f32 (params, dst);
3891+ } break ;
3892+ case GGML_TYPE_F16:
3893+ {
3894+ ggml_compute_forward_geglu_quick_f16 (params, dst);
3895+ } break ;
3896+ default :
3897+ {
3898+ GGML_ABORT (" fatal error" );
3899+ }
3900+ }
3901+ }
3902+
36173903// ggml_compute_forward_norm
36183904
36193905static void ggml_compute_forward_norm_f32 (
@@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
87799065 {
87809066 ggml_compute_forward_swiglu (params, dst);
87819067 } break ;
9068+ case GGML_GLU_OP_GEGLU_ERF:
9069+ {
9070+ ggml_compute_forward_geglu_erf (params, dst);
9071+ } break ;
9072+ case GGML_GLU_OP_GEGLU_QUICK:
9073+ {
9074+ ggml_compute_forward_geglu_quick (params, dst);
9075+ } break ;
87829076 default :
87839077 {
87849078 GGML_ABORT (" fatal error" );
0 commit comments