@@ -3201,14 +3201,24 @@ static void ggml_compute_forward_reglu_f32(
32013201 ggml_tensor * dst) {
32023202
32033203 const ggml_tensor * src0 = dst->src [0 ];
3204+ const ggml_tensor * src1 = dst->src [1 ];
3205+ char * src0_d = (char *) src0->data ;
3206+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3207+ const size_t src0_o = src0->nb [1 ];
3208+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
32043209
32053210 GGML_ASSERT (ggml_is_contiguous_1 (src0));
32063211 GGML_ASSERT (ggml_is_contiguous_1 (dst));
32073212
3213+ if (src1) {
3214+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3215+ GGML_ASSERT (src0->type == src1->type );
3216+ }
3217+
32083218 const int ith = params->ith ;
32093219 const int nth = params->nth ;
32103220
3211- const int nc = src0->ne [0 ] / 2 ;
3221+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
32123222 const int nr = ggml_nrows (src0);
32133223
32143224 GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3224,10 +3234,15 @@ static void ggml_compute_forward_reglu_f32(
32243234 const int ir1 = MIN (ir0 + dr, nr);
32253235
32263236 for (int i1 = ir0; i1 < ir1; i1++) {
3227- ggml_vec_reglu_f32 (nc,
3228- (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3229- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3230- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3237+ float * src0_p = (float *) (src0_d + i1*src0_o);
3238+ float * src1_p = (float *) (src1_d + i1*src1_o);
3239+
3240+ if (!src1) {
3241+ src0_p += swapped ? nc : 0 ;
3242+ src1_p += swapped ? 0 : nc;
3243+ }
3244+
3245+ ggml_vec_reglu_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
32313246
32323247#ifndef NDEBUG
32333248 for (int k = 0 ; k < nc; k++) {
@@ -3245,14 +3260,24 @@ static void ggml_compute_forward_reglu_f16(
32453260 ggml_tensor * dst) {
32463261
32473262 const ggml_tensor * src0 = dst->src [0 ];
3263+ const ggml_tensor * src1 = dst->src [1 ];
3264+ char * src0_d = (char *) src0->data ;
3265+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3266+ const size_t src0_o = src0->nb [1 ];
3267+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
32483268
32493269 GGML_ASSERT (ggml_is_contiguous_1 (src0));
32503270 GGML_ASSERT (ggml_is_contiguous_1 (dst));
32513271
3272+ if (src1) {
3273+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3274+ GGML_ASSERT (src0->type == src1->type );
3275+ }
3276+
32523277 const int ith = params->ith ;
32533278 const int nth = params->nth ;
32543279
3255- const int nc = src0->ne [0 ] / 2 ;
3280+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
32563281 const int nr = ggml_nrows (src0);
32573282
32583283 GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3268,10 +3293,15 @@ static void ggml_compute_forward_reglu_f16(
32683293 const int ir1 = MIN (ir0 + dr, nr);
32693294
32703295 for (int i1 = ir0; i1 < ir1; i1++) {
3271- ggml_vec_reglu_f16 (nc,
3272- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3273- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3274- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3296+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3297+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3298+
3299+ if (!src1) {
3300+ src0_p += swapped ? nc : 0 ;
3301+ src1_p += swapped ? 0 : nc;
3302+ }
3303+
3304+ ggml_vec_reglu_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
32753305
32763306#ifndef NDEBUG
32773307 for (int k = 0 ; k < nc; k++) {
@@ -3314,14 +3344,24 @@ static void ggml_compute_forward_geglu_f32(
33143344 ggml_tensor * dst) {
33153345
33163346 const ggml_tensor * src0 = dst->src [0 ];
3347+ const ggml_tensor * src1 = dst->src [1 ];
3348+ char * src0_d = (char *) src0->data ;
3349+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3350+ const size_t src0_o = src0->nb [1 ];
3351+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
33173352
33183353 GGML_ASSERT (ggml_is_contiguous_1 (src0));
33193354 GGML_ASSERT (ggml_is_contiguous_1 (dst));
33203355
3356+ if (src1) {
3357+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3358+ GGML_ASSERT (src0->type == src1->type );
3359+ }
3360+
33213361 const int ith = params->ith ;
33223362 const int nth = params->nth ;
33233363
3324- const int nc = src0->ne [0 ] / 2 ;
3364+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
33253365 const int nr = ggml_nrows (src0);
33263366
33273367 GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3337,10 +3377,15 @@ static void ggml_compute_forward_geglu_f32(
33373377 const int ir1 = MIN (ir0 + dr, nr);
33383378
33393379 for (int i1 = ir0; i1 < ir1; i1++) {
3340- ggml_vec_geglu_f32 (nc,
3341- (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3342- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3343- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3380+ float * src0_p = (float *) (src0_d + i1*src0_o);
3381+ float * src1_p = (float *) (src1_d + i1*src1_o);
3382+
3383+ if (!src1) {
3384+ src0_p += swapped ? nc : 0 ;
3385+ src1_p += swapped ? 0 : nc;
3386+ }
3387+
3388+ ggml_vec_geglu_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
33443389
33453390#ifndef NDEBUG
33463391 for (int k = 0 ; k < nc; k++) {
@@ -3358,14 +3403,24 @@ static void ggml_compute_forward_geglu_f16(
33583403 ggml_tensor * dst) {
33593404
33603405 const ggml_tensor * src0 = dst->src [0 ];
3406+ const ggml_tensor * src1 = dst->src [1 ];
3407+ char * src0_d = (char *) src0->data ;
3408+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3409+ const size_t src0_o = src0->nb [1 ];
3410+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
33613411
33623412 GGML_ASSERT (ggml_is_contiguous_1 (src0));
33633413 GGML_ASSERT (ggml_is_contiguous_1 (dst));
33643414
3415+ if (src1) {
3416+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3417+ GGML_ASSERT (src0->type == src1->type );
3418+ }
3419+
33653420 const int ith = params->ith ;
33663421 const int nth = params->nth ;
33673422
3368- const int nc = src0->ne [0 ] / 2 ;
3423+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
33693424 const int nr = ggml_nrows (src0);
33703425
33713426 GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3381,10 +3436,15 @@ static void ggml_compute_forward_geglu_f16(
33813436 const int ir1 = MIN (ir0 + dr, nr);
33823437
33833438 for (int i1 = ir0; i1 < ir1; i1++) {
3384- ggml_vec_geglu_f16 (nc,
3385- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3386- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3387- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3439+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3440+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3441+
3442+ if (!src1) {
3443+ src0_p += swapped ? nc : 0 ;
3444+ src1_p += swapped ? 0 : nc;
3445+ }
3446+
3447+ ggml_vec_geglu_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
33883448
33893449#ifndef NDEBUG
33903450 for (int k = 0 ; k < nc; k++) {
@@ -3427,14 +3487,24 @@ static void ggml_compute_forward_swiglu_f32(
34273487 ggml_tensor * dst) {
34283488
34293489 const ggml_tensor * src0 = dst->src [0 ];
3490+ const ggml_tensor * src1 = dst->src [1 ];
3491+ char * src0_d = (char *) src0->data ;
3492+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3493+ const size_t src0_o = src0->nb [1 ];
3494+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
34303495
34313496 GGML_ASSERT (ggml_is_contiguous_1 (src0));
34323497 GGML_ASSERT (ggml_is_contiguous_1 (dst));
34333498
3499+ if (src1) {
3500+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3501+ GGML_ASSERT (src0->type == src1->type );
3502+ }
3503+
34343504 const int ith = params->ith ;
34353505 const int nth = params->nth ;
34363506
3437- const int nc = src0->ne [0 ] / 2 ;
3507+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
34383508 const int nr = ggml_nrows (src0);
34393509
34403510 GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3450,10 +3520,15 @@ static void ggml_compute_forward_swiglu_f32(
34503520 const int ir1 = MIN (ir0 + dr, nr);
34513521
34523522 for (int i1 = ir0; i1 < ir1; i1++) {
3453- ggml_vec_swiglu_f32 (nc,
3454- (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3455- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3456- (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3523+ float * src0_p = (float *) (src0_d + i1*src0_o);
3524+ float * src1_p = (float *) (src1_d + i1*src1_o);
3525+
3526+ if (!src1) {
3527+ src0_p += swapped ? nc : 0 ;
3528+ src1_p += swapped ? 0 : nc;
3529+ }
3530+
3531+ ggml_vec_swiglu_f32 (nc, (float *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
34573532
34583533#ifndef NDEBUG
34593534 for (int k = 0 ; k < nc; k++) {
@@ -3471,14 +3546,24 @@ static void ggml_compute_forward_swiglu_f16(
34713546 ggml_tensor * dst) {
34723547
34733548 const ggml_tensor * src0 = dst->src [0 ];
3549+ const ggml_tensor * src1 = dst->src [1 ];
3550+ char * src0_d = (char *) src0->data ;
3551+ char * src1_d = (char *) (src1 ? src1->data : src0->data );
3552+ const size_t src0_o = src0->nb [1 ];
3553+ const size_t src1_o = src1 ? src1->nb [1 ] : src0->nb [1 ];
34743554
34753555 GGML_ASSERT (ggml_is_contiguous_1 (src0));
34763556 GGML_ASSERT (ggml_is_contiguous_1 (dst));
34773557
3558+ if (src1) {
3559+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3560+ GGML_ASSERT (src0->type == src1->type );
3561+ }
3562+
34783563 const int ith = params->ith ;
34793564 const int nth = params->nth ;
34803565
3481- const int nc = src0->ne [0 ] / 2 ;
3566+ const int nc = src1 ? src0-> ne [ 0 ] : src0->ne [0 ] / 2 ;
34823567 const int nr = ggml_nrows (src0);
34833568
34843569 GGML_ASSERT (dst->ne [0 ] == nc);
@@ -3494,10 +3579,15 @@ static void ggml_compute_forward_swiglu_f16(
34943579 const int ir1 = MIN (ir0 + dr, nr);
34953580
34963581 for (int i1 = ir0; i1 < ir1; i1++) {
3497- ggml_vec_swiglu_f16 (nc,
3498- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3499- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3500- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3582+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3583+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3584+
3585+ if (!src1) {
3586+ src0_p += swapped ? nc : 0 ;
3587+ src1_p += swapped ? 0 : nc;
3588+ }
3589+
3590+ ggml_vec_swiglu_f16 (nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb [1 ])), src0_p, src1_p);
35013591
35023592#ifndef NDEBUG
35033593 for (int k = 0 ; k < nc; k++) {
0 commit comments