
Commit 49227de

More IKL -> Croco convergence and some factoring
1 parent f4b014c commit 49227de

1 file changed: +113 −76 lines

ggml/src/ggml-cpu/ops.cpp

Lines changed: 113 additions & 76 deletions
@@ -70,12 +70,28 @@ static void ggml_compute_forward_dup_f16(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    // parallelize by rows
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    /* // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
     const int dr = (nr + nth - 1) / nth;
     // row range for this thread
     const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr); */
+
+    // parallelize by rows
+    int n_packed = ggml_packed_rows(dst->type);
+    GGML_ASSERT(dst->ne[1] % n_packed == 0);
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
+    // row range for this thread
+    const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
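Note: the reworked split hands out rows in multiples of ggml_packed_rows(dst->type), so a packed group of rows is never divided between two threads; because the per-thread count dr is rounded up, trailing threads can start past the last row, which is what the new `if (ir0 >= nr) return;` guard handles. A minimal standalone sketch of the arithmetic (the sizes below are illustrative, not from the commit):

    #include <cassert>
    #include <cstdio>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main() {
        // illustrative stand-ins: 96 rows, 7 threads, rows packed in groups of 4
        const int nr = 96, nth = 7, n_packed = 4;
        assert(nr % n_packed == 0);
        // per-thread row count, rounded up to a multiple of n_packed
        const int dr = n_packed*((nr/n_packed + nth - 1) / nth);  // 16 here
        for (int ith = 0; ith < nth; ith++) {
            const int ir0 = dr * ith;
            if (ir0 >= nr) {                      // thread 6 gets no rows
                std::printf("thread %d: no rows\n", ith);
                continue;
            }
            const int ir1 = MIN(ir0 + dr, nr);
            std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }

The same split and guard are repeated below for the bf16, f32, and bytes variants of dup.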
@@ -108,7 +124,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                     memcpy(dst_ptr + id, src0_ptr, rs);
                     id += rs;
@@ -123,7 +139,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
                         dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
@@ -144,7 +160,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
                     for (int i00 = 0; i00 < ne00; i00++) {
@@ -170,7 +186,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -188,7 +204,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -334,12 +350,28 @@ static void ggml_compute_forward_dup_bf16(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    // parallelize by rows
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    /* // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
     const int dr = (nr + nth - 1) / nth;
     // row range for this thread
     const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr); */
+
+    // parallelize by rows
+    int n_packed = ggml_packed_rows(dst->type);
+    GGML_ASSERT(dst->ne[1] % n_packed == 0);
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
+    // row range for this thread
+    const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
@@ -372,7 +404,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                     memcpy(dst_ptr + id, src0_ptr, rs);
                     id += rs;
@@ -387,7 +419,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
                         dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
@@ -404,7 +436,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
                         dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
@@ -425,7 +457,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
                     for (int i00 = 0; i00 < ne00; i00++) {
@@ -451,7 +483,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
            for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -469,7 +501,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -487,7 +519,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -685,6 +717,11 @@ static void ggml_compute_forward_dup_f32(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
     // parallelize by rows
     int n_packed = ggml_packed_rows(dst->type);
     GGML_ASSERT(dst->ne[1] % n_packed == 0);
@@ -693,6 +730,7 @@ static void ggml_compute_forward_dup_f32(
     const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
     // row range for this thread
     const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
@@ -724,7 +762,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                     memcpy(dst_ptr + id, src0_ptr, rs);
                     id += rs;
@@ -763,7 +801,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -781,7 +819,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -799,7 +837,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -1008,15 +1046,14 @@ static void ggml_compute_forward_dup_bytes(
     const int nth = params->nth; // number of threads
 
     // parallelize by rows
+    int n_packed = ggml_packed_rows(dst->type);
+    GGML_ASSERT(dst->ne[1] % n_packed == 0);
     const int nr = ne01;
-    const int n_packed = ggml_packed_rows(dst->type);
-    GGML_ASSERT(nr%n_packed == 0);
-    const int nrp = nr/n_packed;
     // number of rows per thread
-    const int drp = (nrp + nth - 1) / nth;
-    const int dr = drp*n_packed;
+    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
     // row range for this thread
     const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
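In dup_bytes the change is a factoring of the existing split rather than new logic: the old three-step nrp/drp/dr computation collapses into the single dr expression used by the other dup variants (the assert now checks dst->ne[1] rather than nr, matching them as well). A self-contained check of the equivalence over some arbitrary illustrative sizes:

    #include <cassert>

    int main() {
        const int n_packed = 4;                        // illustrative packing
        for (int nr = n_packed; nr <= 256; nr += n_packed) {
            for (int nth = 1; nth <= 16; nth++) {
                const int nrp    = nr/n_packed;        // old, three steps
                const int drp    = (nrp + nth - 1) / nth;
                const int dr_old = drp*n_packed;
                const int dr_new = n_packed*((nr/n_packed + nth - 1) / nth);
                assert(dr_old == dr_new);              // folded form is identical
            }
        }
        return 0;
    }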
@@ -1396,6 +1433,59 @@ void ggml_compute_forward_add(
     }
 }
 
+static void ggml_compute_forward_multi_add_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src = dst->src[0];
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src->nb[0] == sizeof(float));
+    GGML_ASSERT(ggml_are_same_shape(src, dst));
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+
+    const int n_add = dst->op_params[0];
+    GGML_ASSERT(n_add > 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    int64_t ne0 = dst->ne[0];
+
+    for (int i1 = ir0; i1 < ir1; ++i1) {
+
+        float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1]);
+        const float * data = (const float *) ((const char *) src->data + i1*src->nb[1]);
+        memset(dst_ptr, 0, ne0*sizeof(float));
+        for (int j = 0; j < n_add; ++j) {
+            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
+        }
+    }
+}
+
+void ggml_compute_forward_multi_add(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            ggml_compute_forward_multi_add_f32(params, dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
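This is the multi_add implementation moved up from later in the file (the matching removal is in the next hunk); apart from dropping the struct keywords it is unchanged. Each output row is zeroed, then n_add consecutive ne0-sized chunks read from the corresponding source row are accumulated into it, so the source row stride src->nb[1] is expected to span at least n_add*ne0 floats. A scalar sketch of the per-row reduction, with ggml_vec_add_f32 replaced by a plain loop (the helper name is illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // one row of multi_add: dst_row[i] = sum_{j=0}^{n_add-1} src_row[j*ne0 + i]
    static void multi_add_row(int64_t ne0, int n_add, float * dst_row, const float * src_row) {
        std::memset(dst_row, 0, ne0*sizeof(float));
        for (int j = 0; j < n_add; ++j) {
            for (int64_t i = 0; i < ne0; ++i) {
                dst_row[i] += src_row[j*ne0 + i];   // j-th ne0-sized chunk
            }
        }
    }

    int main() {
        const float src[4] = {1.0f, 2.0f, 10.0f, 20.0f};  // two chunks of two floats
        float dst[2];
        multi_add_row(2, 2, dst, src);
        std::printf("%g %g\n", dst[0], dst[1]);           // prints: 11 22
        return 0;
    }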
@@ -1705,59 +1795,6 @@ static void ggml_compute_forward_add1_bf16_bf16(
     }
 }
 
-static void ggml_compute_forward_multi_add_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    struct ggml_tensor * src = dst->src[0];
-
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src->nb[0] == sizeof(float));
-    GGML_ASSERT(ggml_are_same_shape(src, dst));
-    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
-
-    const int n_add = dst->op_params[0];
-    GGML_ASSERT(n_add > 0);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    int64_t ne0 = dst->ne[0];
-
-    for (int i1 = ir0; i1 < ir1; ++i1) {
-
-        float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1]);
-        const float * data = (const float *) ((const char *) src->data + i1*src->nb[1]);
-        memset(dst_ptr, 0, ne0*sizeof(float));
-        for (int j = 0; j < n_add; ++j) {
-            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
-        }
-    }
-}
-
-void ggml_compute_forward_multi_add(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    switch (dst->type) {
-        case GGML_TYPE_F32: {
-            ggml_compute_forward_multi_add_f32(params, dst);
-        } break;
-        default: {
-            GGML_ABORT("fatal error");
-        }
-    }
-}
-
 void ggml_compute_forward_add1(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
