@@ -70,12 +70,28 @@ static void ggml_compute_forward_dup_f16(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    // parallelize by rows
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    /* // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
     const int dr = (nr + nth - 1) / nth;
     // row range for this thread
     const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr); */
+
+    // parallelize by rows
+    int n_packed = ggml_packed_rows(dst->type);
+    GGML_ASSERT(dst->ne[1] % n_packed == 0);
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
+    // row range for this thread
+    const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
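Note: the new partitioning hands out rows in multiples of n_packed, so a packed row group is never split across two threads, and the added "if (ir0 >= nr) return;" covers threads that end up with no rows once the per-thread count is rounded up. A minimal standalone sketch of the arithmetic (nr, n_packed, nth here stand in for the values the kernel reads from the tensor and params):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Per-thread row range as computed in the hunk above: rows are assigned
// in whole groups of n_packed, rounded up across nth threads.
static void row_range(int nr, int n_packed, int nth, int ith, int * ir0, int * ir1) {
    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
    *ir0 = dr*ith;
    *ir1 = MIN(*ir0 + dr, nr);
}

int main(void) {
    // nr = 12 rows, n_packed = 4, nth = 8: dr = 4*((3 + 7)/8) = 4, so
    // threads 0..2 get [0,4), [4,8), [8,12), while threads 3..7 get
    // ir0 >= nr -- exactly the case the early-return guard handles.
    for (int ith = 0; ith < 8; ith++) {
        int ir0, ir1;
        row_range(12, 4, 8, ith, &ir0, &ir1);
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}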
@@ -108,7 +124,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                     memcpy(dst_ptr + id, src0_ptr, rs);
                     id += rs;
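The row index now advances by n_packed per iteration while each iteration still copies rs bytes and bumps id by rs once, which only stays consistent if, for packed destination types, rs spans a whole n_packed-row group; that is my reading of the row-size convention in this fork, not something visible in the hunk itself. For ordinary types the loop is unchanged:

// Assumption: ggml_packed_rows() returns 1 for non-packed types, so
// "i01 += n_packed" reduces to the old "i01++" and rs is one row.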
@@ -123,7 +139,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
                         dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
@@ -144,7 +160,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
                     for (int i00 = 0; i00 < ne00; i00++) {
@@ -170,7 +186,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -188,7 +204,7 @@ static void ggml_compute_forward_dup_f16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -334,12 +350,28 @@ static void ggml_compute_forward_dup_bf16(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    // parallelize by rows
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    /* // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
     const int dr = (nr + nth - 1) / nth;
     // row range for this thread
     const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr); */
+
+    // parallelize by rows
+    int n_packed = ggml_packed_rows(dst->type);
+    GGML_ASSERT(dst->ne[1] % n_packed == 0);
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
+    // row range for this thread
+    const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
@@ -372,7 +404,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                     memcpy(dst_ptr + id, src0_ptr, rs);
                     id += rs;
@@ -387,7 +419,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
                         dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
@@ -404,7 +436,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                     for (int i00 = 0; i00 < ne00; i00++) {
                         dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
@@ -425,7 +457,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
                     for (int i00 = 0; i00 < ne00; i00++) {
@@ -451,7 +483,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -469,7 +501,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -487,7 +519,7 @@ static void ggml_compute_forward_dup_bf16(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -685,6 +717,11 @@ static void ggml_compute_forward_dup_f32(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
     // parallelize by rows
     int n_packed = ggml_packed_rows(dst->type);
     GGML_ASSERT(dst->ne[1] % n_packed == 0);
@@ -693,6 +730,7 @@ static void ggml_compute_forward_dup_f32(
     const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
     // row range for this thread
     const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
@@ -724,7 +762,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += rs * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                     memcpy(dst_ptr + id, src0_ptr, rs);
                     id += rs;
@@ -763,7 +801,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -781,7 +819,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -799,7 +837,7 @@ static void ggml_compute_forward_dup_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
            for (int i02 = 0; i02 < ne02; i02++) {
                 id += ne00 * ir0;
-                for (int i01 = ir0; i01 < ir1; i01++) {
+                for (int i01 = ir0; i01 < ir1; i01 += n_packed) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
@@ -1008,15 +1046,14 @@ static void ggml_compute_forward_dup_bytes(
     const int nth = params->nth; // number of threads
 
     // parallelize by rows
+    int n_packed = ggml_packed_rows(dst->type);
+    GGML_ASSERT(dst->ne[1] % n_packed == 0);
     const int nr = ne01;
-    const int n_packed = ggml_packed_rows(dst->type);
-    GGML_ASSERT(nr%n_packed == 0);
-    const int nrp = nr/n_packed;
     // number of rows per thread
-    const int drp = (nrp + nth - 1) / nth;
-    const int dr = drp*n_packed;
+    const int dr = n_packed*((nr/n_packed + nth - 1) / nth);
    // row range for this thread
     const int ir0 = dr * ith;
+    if (ir0 >= nr) return;
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (src0->type == dst->type &&
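The dup_bytes hunk is mostly an algebraic tidy-up: with nrp = nr/n_packed and drp = (nrp + nth - 1)/nth, the old result drp*n_packed equals the new single expression

    dr = n_packed*((nr/n_packed + nth - 1) / nth)

so the per-thread row count is unchanged. Two small behavioral changes ride along: the divisibility assert now checks dst->ne[1] rather than nr (nr = ne01 is the source's row count, so the two differ only when src0 and dst are shaped differently), and threads whose ir0 lands past the last row now return immediately instead of falling through to loops that never execute.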
@@ -1396,6 +1433,59 @@ void ggml_compute_forward_add(
     }
 }
 
+static void ggml_compute_forward_multi_add_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src = dst->src[0];
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src->nb[0] == sizeof(float));
+    GGML_ASSERT(ggml_are_same_shape(src, dst));
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+
+    const int n_add = dst->op_params[0];
+    GGML_ASSERT(n_add > 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    int64_t ne0 = dst->ne[0];
+
+    for (int i1 = ir0; i1 < ir1; ++i1) {
+
+        float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1]);
+        const float * data = (const float *) ((const char *) src->data + i1*src->nb[1]);
+        memset(dst_ptr, 0, ne0*sizeof(float));
+        for (int j = 0; j < n_add; ++j) {
+            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
+        }
+    }
+}
+
+void ggml_compute_forward_multi_add(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            ggml_compute_forward_multi_add_f32(params, dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
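In the relocated multi_add kernel, each thread zeroes its dst rows and then accumulates n_add blocks of ne0 floats read from src at the matching row offset. Since data + j*ne0 for j > 0 points past the ne0 floats of src row i1, for a contiguous src this reads the following rows, i.e. each dst row appears to become the sum of n_add consecutive src rows; I have not checked how the op's builder lays out src, so treat that as a reading of the loop rather than a spec. A scalar equivalent of the per-row body (ggml_vec_add_f32(n, z, x, y) computes z[i] = x[i] + y[i]):

// Scalar sketch of what one dst row accumulates:
// dst_row[k] = sum over j = 0..n_add-1 of data[j*ne0 + k]
static void multi_add_row(int64_t ne0, int n_add, float * dst_row, const float * data) {
    for (int64_t k = 0; k < ne0; ++k) {
        float sum = 0.0f;
        for (int j = 0; j < n_add; ++j) {
            sum += data[j*ne0 + k];
        }
        dst_row[k] = sum;
    }
}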
@@ -1705,59 +1795,6 @@ static void ggml_compute_forward_add1_bf16_bf16(
     }
 }
 
-static void ggml_compute_forward_multi_add_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    struct ggml_tensor * src = dst->src[0];
-
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src->nb[0] == sizeof(float));
-    GGML_ASSERT(ggml_are_same_shape(src, dst));
-    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
-
-    const int n_add = dst->op_params[0];
-    GGML_ASSERT(n_add > 0);
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    int64_t ne0 = dst->ne[0];
-
-    for (int i1 = ir0; i1 < ir1; ++i1) {
-
-        float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1]);
-        const float * data = (const float *) ((const char *) src->data + i1*src->nb[1]);
-        memset(dst_ptr, 0, ne0*sizeof(float));
-        for (int j = 0; j < n_add; ++j) {
-            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
-        }
-    }
-}
-
-void ggml_compute_forward_multi_add(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    switch (dst->type) {
-        case GGML_TYPE_F32: {
-            ggml_compute_forward_multi_add_f32(params, dst);
-        } break;
-        default: {
-            GGML_ABORT("fatal error");
-        }
-    }
-}
-
 void ggml_compute_forward_add1(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
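The removal hunk above is the mirror image of the block added earlier next to ggml_compute_forward_add: the function body is unchanged, but the moved copy drops the struct keywords (struct ggml_compute_params -> ggml_compute_params) and makes src a const pointer, consistent with this file now being compiled as C++ rather than C.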