Commit 6be1307 (parent a5d1fb6)

implement GEGLU_ERF and GEGLU_QUICK ops

17 files changed: +588 −24 lines

ggml/include/ggml.h (28 additions, 0 deletions)

@@ -549,6 +549,8 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU_ERF,
+        GGML_GLU_OP_GEGLU_QUICK,
 
         GGML_GLU_OP_COUNT,
     };
@@ -1136,6 +1138,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1159,6 +1177,16 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
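
These declarations give three call patterns per activation: fused (a single tensor whose rows hold the activated half and the gate half side by side), fused with the halves swapped, and split (two separate tensors). A minimal usage sketch, not part of this commit: the feed-forward wiring, names, and shapes below are illustrative assumptions, and the "first half is activated" behavior is inferred from the CPU kernels later in this diff.

// hypothetical FFN block built on the new fused op
// assumes: `ctx` is a live ggml_context, `w_up` is a fused up+gate
// projection of shape [n_embd, 2*n_ff], `cur` has shape [n_embd, n_tokens]
static struct ggml_tensor * ffn_geglu_erf(
        struct ggml_context * ctx,
        struct ggml_tensor  * w_up,
        struct ggml_tensor  * cur) {
    cur = ggml_mul_mat(ctx, w_up, cur);  // -> [2*n_ff, n_tokens]
    // erf-based GELU is applied to the first half of each row and the
    // result is multiplied by the second half; ggml_geglu_erf_swapped
    // would activate the second half instead
    return ggml_geglu_erf(ctx, cur);     // -> [n_ff, n_tokens]
}

With separate up and gate projections, ggml_geglu_erf_split(ctx, a, b) yields the same result without concatenating the two tensors first.
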
ggml/src/ggml-cpu/ggml-cpu.c (2 additions, 0 deletions)

@@ -2168,6 +2168,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     {
                         n_tasks = n_threads;
                     } break;
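
For GLU ops n_tasks is simply n_threads: each kernel below then partitions its rows across all threads with a ceiling division. A worked example of that split, with illustrative values not taken from the commit:

// nr = 10 rows, nth = 4 threads
// dr = (nr + nth - 1)/nth = (10 + 4 - 1)/4 = 3   rows per thread
// thread ith processes rows [dr*ith, MIN(dr*ith + dr, nr)):
//   ith = 0 -> rows [0, 3)
//   ith = 1 -> rows [3, 6)
//   ith = 2 -> rows [6, 9)
//   ith = 3 -> rows [9, 10)   MIN clamps the last thread to the remainder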

ggml/src/ggml-cpu/ops.cpp (294 additions, 0 deletions)

@@ -3613,6 +3613,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -8502,6 +8788,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
            {
                GGML_ABORT("fatal error");
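
The element-wise work is delegated to ggml_vec_geglu_erf_{f32,f16} and ggml_vec_geglu_quick_{f32,f16}, which are defined elsewhere among the 17 changed files and are not shown in this excerpt. As a scalar reference, here is a sketch of what those helpers are expected to compute, assuming the standard erf-based GELU and the common 1.702 sigmoid approximation for "quick" GELU (the exact helper bodies and constants are assumptions, not copied from this commit):

#include <math.h>

// GEGLU_ERF reference: y[i] = 0.5*x[i]*(1 + erf(x[i]/sqrt(2))) * g[i]
static void ref_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; i++) {
        y[i] = 0.5f*x[i]*(1.0f + erff(x[i]/sqrtf(2.0f))) * g[i];
    }
}

// GEGLU_QUICK reference: y[i] = x[i]*sigmoid(1.702*x[i]) * g[i]
static void ref_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; i++) {
        y[i] = x[i]/(1.0f + expf(-1.702f*x[i])) * g[i];
    }
}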