Skip to content

Commit 92943e7

Browse files
authored
relax constraints
1 parent c717198 commit 92943e7

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3208,11 +3208,11 @@ static void ggml_compute_forward_reglu_f32(
32083208
const int ith = params->ith;
32093209
const int nth = params->nth;
32103210

3211-
const int nc = dst->ne[0];
3211+
const int nc = src0->ne[0] / 2;
32123212
const int nr = ggml_nrows(src0);
32133213

3214-
GGML_ASSERT(src0->ne[0] / 2 == nc);
3215-
GGML_ASSERT(ggml_nrows(dst) == nr);
3214+
GGML_ASSERT(dst->ne[0] >= nc);
3215+
GGML_ASSERT(ggml_nrows(dst) >= nr);
32163216

32173217
// rows per thread
32183218
const int dr = (nr + nth - 1)/nth;
@@ -3249,11 +3249,11 @@ static void ggml_compute_forward_reglu_f16(
32493249
const int ith = params->ith;
32503250
const int nth = params->nth;
32513251

3252-
const int nc = dst->ne[0];
3252+
const int nc = src0->ne[0] / 2;
32533253
const int nr = ggml_nrows(src0);
32543254

3255-
GGML_ASSERT(src0->ne[0] / 2 == nc);
3256-
GGML_ASSERT(ggml_nrows(dst) == nr);
3255+
GGML_ASSERT(dst->ne[0] >= nc);
3256+
GGML_ASSERT(ggml_nrows(dst) >= nr);
32573257

32583258
// rows per thread
32593259
const int dr = (nr + nth - 1)/nth;
@@ -3315,11 +3315,11 @@ static void ggml_compute_forward_geglu_f32(
33153315
const int ith = params->ith;
33163316
const int nth = params->nth;
33173317

3318-
const int nc = dst->ne[0];
3318+
const int nc = src0->ne[0] / 2;
33193319
const int nr = ggml_nrows(src0);
33203320

3321-
GGML_ASSERT(src0->ne[0] / 2 == nc);
3322-
GGML_ASSERT(ggml_nrows(dst) == nr);
3321+
GGML_ASSERT(dst->ne[0] >= nc);
3322+
GGML_ASSERT(ggml_nrows(dst) >= nr);
33233323

33243324
// rows per thread
33253325
const int dr = (nr + nth - 1)/nth;
@@ -3356,11 +3356,11 @@ static void ggml_compute_forward_geglu_f16(
33563356
const int ith = params->ith;
33573357
const int nth = params->nth;
33583358

3359-
const int nc = dst->ne[0];
3359+
const int nc = src0->ne[0] / 2;
33603360
const int nr = ggml_nrows(src0);
33613361

3362-
GGML_ASSERT(src0->ne[0] / 2 == nc);
3363-
GGML_ASSERT(ggml_nrows(dst) == nr);
3362+
GGML_ASSERT(dst->ne[0] >= nc);
3363+
GGML_ASSERT(ggml_nrows(dst) >= nr);
33643364

33653365
// rows per thread
33663366
const int dr = (nr + nth - 1)/nth;
@@ -3422,11 +3422,11 @@ static void ggml_compute_forward_swiglu_f32(
34223422
const int ith = params->ith;
34233423
const int nth = params->nth;
34243424

3425-
const int nc = dst->ne[0];
3425+
const int nc = src0->ne[0] / 2;
34263426
const int nr = ggml_nrows(src0);
34273427

3428-
GGML_ASSERT(src0->ne[0] / 2 == nc);
3429-
GGML_ASSERT(ggml_nrows(dst) == nr);
3428+
GGML_ASSERT(dst->ne[0] >= nc);
3429+
GGML_ASSERT(ggml_nrows(dst) >= nr);
34303430

34313431
// rows per thread
34323432
const int dr = (nr + nth - 1)/nth;
@@ -3463,11 +3463,11 @@ static void ggml_compute_forward_swiglu_f16(
34633463
const int ith = params->ith;
34643464
const int nth = params->nth;
34653465

3466-
const int nc = dst->ne[0];
3466+
const int nc = src0->ne[0] / 2;
34673467
const int nr = ggml_nrows(src0);
34683468

3469-
GGML_ASSERT(src0->ne[0] / 2 == nc);
3470-
GGML_ASSERT(ggml_nrows(dst) == nr);
3469+
GGML_ASSERT(dst->ne[0] >= nc);
3470+
GGML_ASSERT(ggml_nrows(dst) >= nr);
34713471

34723472
// rows per thread
34733473
const int dr = (nr + nth - 1)/nth;

0 commit comments

Comments
 (0)