@@ -3208,11 +3208,11 @@ static void ggml_compute_forward_reglu_f32(
32083208 const int ith = params->ith ;
32093209 const int nth = params->nth ;
32103210
3211- const int nc = dst ->ne [0 ];
3211+ const int nc = src0 ->ne [0 ] / 2 ;
32123212 const int nr = ggml_nrows (src0);
32133213
3214- GGML_ASSERT (src0 ->ne [0 ] / 2 = = nc);
3215- GGML_ASSERT (ggml_nrows (dst) = = nr);
3214+ GGML_ASSERT (dst ->ne [0 ] > = nc);
3215+ GGML_ASSERT (ggml_nrows (dst) > = nr);
32163216
32173217 // rows per thread
32183218 const int dr = (nr + nth - 1 )/nth;
@@ -3249,11 +3249,11 @@ static void ggml_compute_forward_reglu_f16(
32493249 const int ith = params->ith ;
32503250 const int nth = params->nth ;
32513251
3252- const int nc = dst ->ne [0 ];
3252+ const int nc = src0 ->ne [0 ] / 2 ;
32533253 const int nr = ggml_nrows (src0);
32543254
3255- GGML_ASSERT (src0 ->ne [0 ] / 2 = = nc);
3256- GGML_ASSERT (ggml_nrows (dst) = = nr);
3255+ GGML_ASSERT (dst ->ne [0 ] > = nc);
3256+ GGML_ASSERT (ggml_nrows (dst) > = nr);
32573257
32583258 // rows per thread
32593259 const int dr = (nr + nth - 1 )/nth;
@@ -3315,11 +3315,11 @@ static void ggml_compute_forward_geglu_f32(
33153315 const int ith = params->ith ;
33163316 const int nth = params->nth ;
33173317
3318- const int nc = dst ->ne [0 ];
3318+ const int nc = src0 ->ne [0 ] / 2 ;
33193319 const int nr = ggml_nrows (src0);
33203320
3321- GGML_ASSERT (src0 ->ne [0 ] / 2 = = nc);
3322- GGML_ASSERT (ggml_nrows (dst) = = nr);
3321+ GGML_ASSERT (dst ->ne [0 ] > = nc);
3322+ GGML_ASSERT (ggml_nrows (dst) > = nr);
33233323
33243324 // rows per thread
33253325 const int dr = (nr + nth - 1 )/nth;
@@ -3356,11 +3356,11 @@ static void ggml_compute_forward_geglu_f16(
33563356 const int ith = params->ith ;
33573357 const int nth = params->nth ;
33583358
3359- const int nc = dst ->ne [0 ];
3359+ const int nc = src0 ->ne [0 ] / 2 ;
33603360 const int nr = ggml_nrows (src0);
33613361
3362- GGML_ASSERT (src0 ->ne [0 ] / 2 = = nc);
3363- GGML_ASSERT (ggml_nrows (dst) = = nr);
3362+ GGML_ASSERT (dst ->ne [0 ] > = nc);
3363+ GGML_ASSERT (ggml_nrows (dst) > = nr);
33643364
33653365 // rows per thread
33663366 const int dr = (nr + nth - 1 )/nth;
@@ -3422,11 +3422,11 @@ static void ggml_compute_forward_swiglu_f32(
34223422 const int ith = params->ith ;
34233423 const int nth = params->nth ;
34243424
3425- const int nc = dst ->ne [0 ];
3425+ const int nc = src0 ->ne [0 ] / 2 ;
34263426 const int nr = ggml_nrows (src0);
34273427
3428- GGML_ASSERT (src0 ->ne [0 ] / 2 = = nc);
3429- GGML_ASSERT (ggml_nrows (dst) = = nr);
3428+ GGML_ASSERT (dst ->ne [0 ] > = nc);
3429+ GGML_ASSERT (ggml_nrows (dst) > = nr);
34303430
34313431 // rows per thread
34323432 const int dr = (nr + nth - 1 )/nth;
@@ -3463,11 +3463,11 @@ static void ggml_compute_forward_swiglu_f16(
34633463 const int ith = params->ith ;
34643464 const int nth = params->nth ;
34653465
3466- const int nc = dst ->ne [0 ];
3466+ const int nc = src0 ->ne [0 ] / 2 ;
34673467 const int nr = ggml_nrows (src0);
34683468
3469- GGML_ASSERT (src0 ->ne [0 ] / 2 = = nc);
3470- GGML_ASSERT (ggml_nrows (dst) = = nr);
3469+ GGML_ASSERT (dst ->ne [0 ] > = nc);
3470+ GGML_ASSERT (ggml_nrows (dst) > = nr);
34713471
34723472 // rows per thread
34733473 const int dr = (nr + nth - 1 )/nth;
0 commit comments