@@ -3487,10 +3487,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
34873487#if defined(__ARM_FEATURE_MATMUL_INT8)
34883488 if (nrc == 2) {
34893489 const block_q4_0 * restrict vx0 = vx;
3490- const block_q4_0 * restrict vx1 = vx + bx;
3491-
3490+ const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
34923491 const block_q8_0 * restrict vy0 = vy;
3493- const block_q8_0 * restrict vy1 = vy + by;
3492+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*) vy + by) ;
34943493
34953494 float32x4_t sumv0 = vdupq_n_f32(0.0f);
34963495
@@ -3524,10 +3523,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
35243523 const int8x16_t y1_l = vld1q_s8(b_y1->qs);
35253524 const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
35263525
3527- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3528- GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3529- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3530- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3526+ float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3527+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3528+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3529+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3530+
3531+ float32x4_t scale = vld1q_f32(_scale);
35313532
35323533 int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
35333534 int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3894,9 +3895,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
38943895#if defined(__ARM_FEATURE_MATMUL_INT8)
38953896 if (nrc == 2) {
38963897 const block_q4_1 * restrict vx0 = vx;
3897- const block_q4_1 * restrict vx1 = vx + bx;
3898+ const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*) vx + bx) ;
38983899 const block_q8_1 * restrict vy0 = vy;
3899- const block_q8_1 * restrict vy1 = vy + by;
3900+ const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*) vy + by) ;
39003901
39013902 float32x4_t sumv0 = vdupq_n_f32(0.0f);
39023903 float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3907,11 +3908,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
39073908 const block_q8_1 * restrict b_y0 = &vy0[i];
39083909 const block_q8_1 * restrict b_y1 = &vy1[i];
39093910
3910- float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3911- GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3912- GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3913- GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3914- summs0 += summs_t;
3911+ float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3912+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3913+ GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3914+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3915+ summs0 = vaddq_f32(summs0, vld1q_f32( summs_t)) ;
39153916
39163917 const uint8x16_t m4b = vdupq_n_u8(0x0F);
39173918
@@ -3931,10 +3932,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
39313932 const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
39323933
39333934 // mmla into int32x4_t
3934- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3935- GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3936- GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3937- GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3935+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3936+ GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3937+ GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3938+ GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3939+ float32x4_t scale = vld1q_f32(_scale);
39383940
39393941 int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
39403942 int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3953,7 +3955,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
39533955
39543956 float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
39553957 float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3956- sumv2 = sumv2 + summs0;
3958+ sumv2 = vaddq_f32( sumv2, summs0) ;
39573959
39583960 vst1_f32(s, vget_low_f32(sumv2));
39593961 vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4836,35 +4838,36 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
48364838
48374839#if defined(__ARM_FEATURE_MATMUL_INT8)
48384840 if (nrc == 2) {
4839- const block_q8_0 * restrict vx0 = vx;
4840- const block_q8_0 * restrict vx1 = vx + bx;
4841+ const block_q4_0 * restrict vx0 = vx;
4842+ const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*) vx + bx) ;
48414843 const block_q8_0 * restrict vy0 = vy;
4842- const block_q8_0 * restrict vy1 = vy + by;
4844+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*) vy + by) ;
48434845
48444846 float32x4_t sumv0 = vdupq_n_f32(0.0f);
48454847
48464848 for (int i = 0; i < nb; i++) {
4847- const block_q8_0 * restrict b_x0 = &vx0[i];
4849+ const block_q4_0 * restrict b_x0 = &vx0[i];
48484850 const block_q8_0 * restrict b_y0 = &vy0[i];
48494851
4850- const block_q8_0 * restrict b_x1 = &vx1[i];
4852+ const block_q4_0 * restrict b_x1 = &vx1[i];
48514853 const block_q8_0 * restrict b_y1 = &vy1[i];
48524854
4853- const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4854- const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4855- const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4856- const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4855+ const int8x16_t x0_l = vld1q_s8((const int8_t*) b_x0->qs);
4856+ const int8x16_t x0_h = vld1q_s8((const int8_t*) b_x0->qs + 16);
4857+ const int8x16_t x1_l = vld1q_s8((const int8_t*) b_x1->qs);
4858+ const int8x16_t x1_h = vld1q_s8((const int8_t*) b_x1->qs + 16);
48574859
48584860 // load y
48594861 const int8x16_t y0_l = vld1q_s8(b_y0->qs);
48604862 const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
48614863 const int8x16_t y1_l = vld1q_s8(b_y1->qs);
48624864 const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
48634865
4864- float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4865- GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4866- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4867- GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4866+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4867+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4868+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4869+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4870+ float32x4_t scale = vld1q_f32(_scale);
48684871
48694872 int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
48704873 int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
0 commit comments