@@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32(
32143214 GGML_ASSERT (dst->ne [0 ] == nc);
32153215 GGML_ASSERT (ggml_nrows (dst) == nr);
32163216
3217+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3218+
32173219 // rows per thread
32183220 const int dr = (nr + nth - 1 )/nth;
32193221
@@ -3224,7 +3226,8 @@ static void ggml_compute_forward_reglu_f32(
32243226 for (int i1 = ir0; i1 < ir1; i1++) {
32253227 ggml_vec_reglu_f32 (nc,
32263228 (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3227- (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
3229+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3230+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
32283231
32293232#ifndef NDEBUG
32303233 for (int k = 0 ; k < nc; k++) {
@@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16(
32553258 GGML_ASSERT (dst->ne [0 ] == nc);
32563259 GGML_ASSERT (ggml_nrows (dst) == nr);
32573260
3261+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3262+
32583263 // rows per thread
32593264 const int dr = (nr + nth - 1 )/nth;
32603265
@@ -3265,7 +3270,8 @@ static void ggml_compute_forward_reglu_f16(
32653270 for (int i1 = ir0; i1 < ir1; i1++) {
32663271 ggml_vec_reglu_f16 (nc,
32673272 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3268- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
3273+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3274+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
32693275
32703276#ifndef NDEBUG
32713277 for (int k = 0 ; k < nc; k++) {
@@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32(
33213327 GGML_ASSERT (dst->ne [0 ] == nc);
33223328 GGML_ASSERT (ggml_nrows (dst) == nr);
33233329
3330+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3331+
33243332 // rows per thread
33253333 const int dr = (nr + nth - 1 )/nth;
33263334
@@ -3331,7 +3339,8 @@ static void ggml_compute_forward_geglu_f32(
33313339 for (int i1 = ir0; i1 < ir1; i1++) {
33323340 ggml_vec_geglu_f32 (nc,
33333341 (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3334- (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
3342+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3343+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
33353344
33363345#ifndef NDEBUG
33373346 for (int k = 0 ; k < nc; k++) {
@@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16(
33623371 GGML_ASSERT (dst->ne [0 ] == nc);
33633372 GGML_ASSERT (ggml_nrows (dst) == nr);
33643373
3374+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3375+
33653376 // rows per thread
33663377 const int dr = (nr + nth - 1 )/nth;
33673378
@@ -3372,7 +3383,8 @@ static void ggml_compute_forward_geglu_f16(
33723383 for (int i1 = ir0; i1 < ir1; i1++) {
33733384 ggml_vec_geglu_f16 (nc,
33743385 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3375- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
3386+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3387+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
33763388
33773389#ifndef NDEBUG
33783390 for (int k = 0 ; k < nc; k++) {
@@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32(
34283440 GGML_ASSERT (dst->ne [0 ] == nc);
34293441 GGML_ASSERT (ggml_nrows (dst) == nr);
34303442
3443+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3444+
34313445 // rows per thread
34323446 const int dr = (nr + nth - 1 )/nth;
34333447
@@ -3438,7 +3452,8 @@ static void ggml_compute_forward_swiglu_f32(
34383452 for (int i1 = ir0; i1 < ir1; i1++) {
34393453 ggml_vec_swiglu_f32 (nc,
34403454 (float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3441- (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
3455+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3456+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
34423457
34433458#ifndef NDEBUG
34443459 for (int k = 0 ; k < nc; k++) {
@@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16(
34693484 GGML_ASSERT (dst->ne [0 ] == nc);
34703485 GGML_ASSERT (ggml_nrows (dst) == nr);
34713486
3487+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3488+
34723489 // rows per thread
34733490 const int dr = (nr + nth - 1 )/nth;
34743491
@@ -3479,7 +3496,8 @@ static void ggml_compute_forward_swiglu_f16(
34793496 for (int i1 = ir0; i1 < ir1; i1++) {
34803497 ggml_vec_swiglu_f16 (nc,
34813498 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3482- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
3499+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3500+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
34833501
34843502#ifndef NDEBUG
34853503 for (int k = 0 ; k < nc; k++) {
0 commit comments