@@ -3613,6 +3613,292 @@ static void ggml_compute_forward_swiglu(
36133613 }
36143614}
36153615
3616+ // ggml_compute_forward_geglu_erf
3617+
3618+ static void ggml_compute_forward_geglu_erf_f32 (
3619+ const ggml_compute_params * params,
3620+ ggml_tensor * dst) {
3621+
3622+ const ggml_tensor * src0 = dst->src [0 ];
3623+ const ggml_tensor * src1 = dst->src [1 ];
3624+ char * src0_d = (char *) src0->data ;
3625+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3626+ const size_t src0_o = src0->nb [1 ];
3627+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3628+
3629+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3630+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3631+
3632+ if (src1) {
3633+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3634+ GGML_ASSERT (src0->type == src1->type );
3635+ }
3636+
3637+ const int ith = params->ith ;
3638+ const int nth = params->nth ;
3639+
3640+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3641+ const int nr = ggml_nrows (src0);
3642+
3643+ GGML_ASSERT (dst->ne [0 ] == nc);
3644+ GGML_ASSERT (ggml_nrows (dst) == nr);
3645+
3646+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3647+
3648+ // rows per thread
3649+ const int dr = (nr + nth - 1 )/nth;
3650+
3651+ // row range for this thread
3652+ const int ir0 = dr*ith;
3653+ const int ir1 = MIN (ir0 + dr, nr);
3654+
3655+ for (int i1 = ir0; i1 < ir1; i1++) {
3656+ float * src0_p = (float *) (src0_d + i1*src0_o);
3657+ float * src1_p = (float *) (src1_d + i1*src1_o);
3658+
3659+ if (!src1) {
3660+ src0_p += swapped ? nc : 0 ;
3661+ src1_p += swapped ? 0 : nc;
3662+ }
3663+
3664+ ggml_vec_geglu_erf_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3665+
3666+ #ifndef NDEBUG
3667+ for (int k = 0 ; k < nc; k++) {
3668+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3669+ GGML_UNUSED (x);
3670+ assert (!isnan (x));
3671+ assert (!isinf (x));
3672+ }
3673+ #endif
3674+ }
3675+ }
3676+
3677+ static void ggml_compute_forward_geglu_erf_f16 (
3678+ const ggml_compute_params * params,
3679+ ggml_tensor * dst) {
3680+
3681+ const ggml_tensor * src0 = dst->src [0 ];
3682+ const ggml_tensor * src1 = dst->src [1 ];
3683+ char * src0_d = (char *) src0->data ;
3684+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3685+ const size_t src0_o = src0->nb [1 ];
3686+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3687+
3688+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3689+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3690+
3691+ if (src1) {
3692+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3693+ GGML_ASSERT (src0->type == src1->type );
3694+ }
3695+
3696+ const int ith = params->ith ;
3697+ const int nth = params->nth ;
3698+
3699+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3700+ const int nr = ggml_nrows (src0);
3701+
3702+ GGML_ASSERT (dst->ne [0 ] == nc);
3703+ GGML_ASSERT (ggml_nrows (dst) == nr);
3704+
3705+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3706+
3707+ // rows per thread
3708+ const int dr = (nr + nth - 1 )/nth;
3709+
3710+ // row range for this thread
3711+ const int ir0 = dr*ith;
3712+ const int ir1 = MIN (ir0 + dr, nr);
3713+
3714+ for (int i1 = ir0; i1 < ir1; i1++) {
3715+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3716+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3717+
3718+ if (!src1) {
3719+ src0_p += swapped ? nc : 0 ;
3720+ src1_p += swapped ? 0 : nc;
3721+ }
3722+
3723+ ggml_vec_geglu_erf_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3724+
3725+ #ifndef NDEBUG
3726+ for (int k = 0 ; k < nc; k++) {
3727+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3728+ const float v = GGML_FP16_TO_FP32 (x);
3729+ GGML_UNUSED (v);
3730+ assert (!isnan (v));
3731+ assert (!isinf (v));
3732+ }
3733+ #endif
3734+ }
3735+ }
3736+
3737+ static void ggml_compute_forward_geglu_erf (
3738+ const ggml_compute_params * params,
3739+ ggml_tensor * dst) {
3740+
3741+ const ggml_tensor * src0 = dst->src [0 ];
3742+
3743+ switch (src0->type ) {
3744+ case GGML_TYPE_F32:
3745+ {
3746+ ggml_compute_forward_geglu_erf_f32 (params, dst);
3747+ } break ;
3748+ case GGML_TYPE_F16:
3749+ {
3750+ ggml_compute_forward_geglu_erf_f16 (params, dst);
3751+ } break ;
3752+ default :
3753+ {
3754+ GGML_ABORT (" fatal error" );
3755+ }
3756+ }
3757+ }
3758+
3759+ // ggml_compute_forward_geglu_quick
3760+
3761+ static void ggml_compute_forward_geglu_quick_f32 (
3762+ const ggml_compute_params * params,
3763+ ggml_tensor * dst) {
3764+
3765+ const ggml_tensor * src0 = dst->src [0 ];
3766+ const ggml_tensor * src1 = dst->src [1 ];
3767+ char * src0_d = (char *) src0->data ;
3768+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3769+ const size_t src0_o = src0->nb [1 ];
3770+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3771+
3772+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3773+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3774+
3775+ if (src1) {
3776+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3777+ GGML_ASSERT (src0->type == src1->type );
3778+ }
3779+
3780+ const int ith = params->ith ;
3781+ const int nth = params->nth ;
3782+
3783+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3784+ const int nr = ggml_nrows (src0);
3785+
3786+ GGML_ASSERT (dst->ne [0 ] == nc);
3787+ GGML_ASSERT (ggml_nrows (dst) == nr);
3788+
3789+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3790+
3791+ // rows per thread
3792+ const int dr = (nr + nth - 1 )/nth;
3793+
3794+ // row range for this thread
3795+ const int ir0 = dr*ith;
3796+ const int ir1 = MIN (ir0 + dr, nr);
3797+
3798+ for (int i1 = ir0; i1 < ir1; i1++) {
3799+ float * src0_p = (float *) (src0_d + i1*src0_o);
3800+ float * src1_p = (float *) (src1_d + i1*src1_o);
3801+
3802+ if (!src1) {
3803+ src0_p += swapped ? nc : 0 ;
3804+ src1_p += swapped ? 0 : nc;
3805+ }
3806+
3807+ ggml_vec_geglu_quick_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3808+
3809+ #ifndef NDEBUG
3810+ for (int k = 0 ; k < nc; k++) {
3811+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3812+ GGML_UNUSED (x);
3813+ assert (!isnan (x));
3814+ assert (!isinf (x));
3815+ }
3816+ #endif
3817+ }
3818+ }
3819+
3820+ static void ggml_compute_forward_geglu_quick_f16 (
3821+ const ggml_compute_params * params,
3822+ ggml_tensor * dst) {
3823+
3824+ const ggml_tensor * src0 = dst->src [0 ];
3825+ const ggml_tensor * src1 = dst->src [1 ];
3826+ char * src0_d = (char *) src0->data ;
3827+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3828+ const size_t src0_o = src0->nb [1 ];
3829+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
3830+
3831+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3832+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3833+
3834+ if (src1) {
3835+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3836+ GGML_ASSERT (src0->type == src1->type );
3837+ }
3838+
3839+ const int ith = params->ith ;
3840+ const int nth = params->nth ;
3841+
3842+ const int nc = src1 ? src0->ne [0 ] : src0->ne [0 ] / 2 ;
3843+ const int nr = ggml_nrows (src0);
3844+
3845+ GGML_ASSERT (dst->ne [0 ] == nc);
3846+ GGML_ASSERT (ggml_nrows (dst) == nr);
3847+
3848+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3849+
3850+ // rows per thread
3851+ const int dr = (nr + nth - 1 )/nth;
3852+
3853+ // row range for this thread
3854+ const int ir0 = dr*ith;
3855+ const int ir1 = MIN (ir0 + dr, nr);
3856+
3857+ for (int i1 = ir0; i1 < ir1; i1++) {
3858+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3859+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3860+
3861+ if (!src1) {
3862+ src0_p += swapped ? nc : 0 ;
3863+ src1_p += swapped ? 0 : nc;
3864+ }
3865+
3866+ ggml_vec_geglu_quick_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
3867+
3868+ #ifndef NDEBUG
3869+ for (int k = 0 ; k < nc; k++) {
3870+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])))[k];
3871+ const float v = GGML_FP16_TO_FP32 (x);
3872+ GGML_UNUSED (v);
3873+ assert (!isnan (v));
3874+ assert (!isinf (v));
3875+ }
3876+ #endif
3877+ }
3878+ }
3879+
3880+ static void ggml_compute_forward_geglu_quick (
3881+ const ggml_compute_params * params,
3882+ ggml_tensor * dst) {
3883+
3884+ const ggml_tensor * src0 = dst->src [0 ];
3885+
3886+ switch (src0->type ) {
3887+ case GGML_TYPE_F32:
3888+ {
3889+ ggml_compute_forward_geglu_quick_f32 (params, dst);
3890+ } break ;
3891+ case GGML_TYPE_F16:
3892+ {
3893+ ggml_compute_forward_geglu_quick_f16 (params, dst);
3894+ } break ;
3895+ default :
3896+ {
3897+ GGML_ABORT (" fatal error" );
3898+ }
3899+ }
3900+ }
3901+
36163902// ggml_compute_forward_norm
36173903
36183904static void ggml_compute_forward_norm_f32 (
@@ -8502,6 +8788,14 @@ void ggml_compute_forward_glu(
85028788 {
85038789 ggml_compute_forward_swiglu (params, dst);
85048790 } break ;
8791+ case GGML_GLU_OP_GEGLU_ERF:
8792+ {
8793+ ggml_compute_forward_geglu_erf (params, dst);
8794+ } break ;
8795+ case GGML_GLU_OP_GEGLU_QUICK:
8796+ {
8797+ ggml_compute_forward_geglu_quick (params, dst);
8798+ } break ;
85058799 default :
85068800 {
85078801 GGML_ABORT (" fatal error" );
0 commit comments