@@ -1791,11 +1791,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
17911791 const int8x16_t y1_l = vld1q_s8 (b_y1 -> qs );
17921792 const int8x16_t y1_h = vld1q_s8 (b_y1 -> qs + 16 );
17931793
1794- float32_t _scale [4 ] = { GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
1795- GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d ),
1796- GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
1797- GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d )};
1798-
1794+ float32_t _scale [4 ] = {
1795+ GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
1796+ GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d ),
1797+ GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
1798+ GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d )
1799+ };
17991800 float32x4_t scale = vld1q_f32 (_scale );
18001801
18011802 int8x16_t l0 = vreinterpretq_s8_s64 (vzip1q_s64 (vreinterpretq_s64_s8 (x0_l ), vreinterpretq_s64_s8 (x1_l )));
@@ -1811,7 +1812,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
18111812 int8x16_t r3 = vreinterpretq_s8_s64 (vzip2q_s64 (vreinterpretq_s64_s8 (y0_h ), vreinterpretq_s64_s8 (y1_h )));
18121813
18131814 sumv0 = vmlaq_f32 (sumv0 ,(vcvtq_f32_s32 (vmmlaq_s32 ((vmmlaq_s32 ((vmmlaq_s32 ((vmmlaq_s32 (vdupq_n_s32 (0 ), l0 , r0 )),
1814- l1 , r1 )), l2 , r2 )), l3 , r3 ))), scale );
1815+ l1 , r1 )), l2 , r2 )), l3 , r3 ))), scale );
18151816 }
18161817
18171818 float32x4_t sumv1 = vextq_f32 (sumv0 , sumv0 , 2 );
@@ -2347,10 +2348,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
23472348 const block_q8_1 * restrict b_y0 = & vy0 [i ];
23482349 const block_q8_1 * restrict b_y1 = & vy1 [i ];
23492350
2350- float32_t summs_t [4 ] = {GGML_FP16_TO_FP32 (b_x0 -> m ) * GGML_FP16_TO_FP32 (b_y0 -> s ),
2351- GGML_FP16_TO_FP32 (b_x1 -> m ) * GGML_FP16_TO_FP32 (b_y0 -> s ),
2352- GGML_FP16_TO_FP32 (b_x0 -> m ) * GGML_FP16_TO_FP32 (b_y1 -> s ),
2353- GGML_FP16_TO_FP32 (b_x1 -> m ) * GGML_FP16_TO_FP32 (b_y1 -> s )};
2351+ float32_t summs_t [4 ] = {
2352+ GGML_FP16_TO_FP32 (b_x0 -> m ) * GGML_FP16_TO_FP32 (b_y0 -> s ),
2353+ GGML_FP16_TO_FP32 (b_x1 -> m ) * GGML_FP16_TO_FP32 (b_y0 -> s ),
2354+ GGML_FP16_TO_FP32 (b_x0 -> m ) * GGML_FP16_TO_FP32 (b_y1 -> s ),
2355+ GGML_FP16_TO_FP32 (b_x1 -> m ) * GGML_FP16_TO_FP32 (b_y1 -> s )
2356+ };
23542357 summs0 = vaddq_f32 (summs0 , vld1q_f32 (summs_t ));
23552358
23562359 const uint8x16_t m4b = vdupq_n_u8 (0x0F );
@@ -2371,10 +2374,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
23712374 const int8x16_t y1_h = vld1q_s8 (b_y1 -> qs + 16 );
23722375
23732376 // mmla into int32x4_t
2374- float32_t _scale [4 ] = {GGML_FP16_TO_FP32 (b_x0 -> d )* b_y0 -> d ,
2375- GGML_FP16_TO_FP32 (b_x0 -> d )* b_y1 -> d ,
2376- GGML_FP16_TO_FP32 (b_x1 -> d )* b_y0 -> d ,
2377- GGML_FP16_TO_FP32 (b_x1 -> d )* b_y1 -> d };
2377+ float32_t _scale [4 ] = {
2378+ GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
2379+ GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d ),
2380+ GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
2381+ GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d )
2382+ };
23782383 float32x4_t scale = vld1q_f32 (_scale );
23792384
23802385 int8x16_t l0 = vreinterpretq_s8_s64 (vzip1q_s64 (vreinterpretq_s64_s8 (x0_l ), vreinterpretq_s64_s8 (x1_l )));
@@ -2389,15 +2394,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
23892394 int8x16_t r2 = vreinterpretq_s8_s64 (vzip1q_s64 (vreinterpretq_s64_s8 (y0_h ), vreinterpretq_s64_s8 (y1_h )));
23902395 int8x16_t r3 = vreinterpretq_s8_s64 (vzip2q_s64 (vreinterpretq_s64_s8 (y0_h ), vreinterpretq_s64_s8 (y1_h )));
23912396 sumv0 = vmlaq_f32 (sumv0 ,(vcvtq_f32_s32 (vmmlaq_s32 ((vmmlaq_s32 ((vmmlaq_s32 ((vmmlaq_s32 (vdupq_n_s32 (0 ), l0 , r0 )),
2392- l1 , r1 )), l2 , r2 )), l3 , r3 ))), scale );
2397+ l1 , r1 )), l2 , r2 )), l3 , r3 ))), scale );
23932398 }
23942399
2395- float32x4_t sumv1 = vextq_f32 (sumv0 , sumv0 , 2 );
2400+ float32x4_t sumv1 = vextq_f32 (sumv0 , sumv0 , 2 );
23962401 float32x4_t sumv2 = vzip1q_f32 (sumv0 , sumv1 );
2402+
23972403 sumv2 = vaddq_f32 (sumv2 , summs0 );
23982404
23992405 vst1_f32 (s , vget_low_f32 (sumv2 ));
24002406 vst1_f32 (s + bs , vget_high_f32 (sumv2 ));
2407+
24012408 return ;
24022409 }
24032410#endif
@@ -3374,10 +3381,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
33743381 const int8x16_t y1_l = vld1q_s8 (b_y1 -> qs );
33753382 const int8x16_t y1_h = vld1q_s8 (b_y1 -> qs + 16 );
33763383
3377- float32_t _scale [4 ] = {GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
3378- GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d ),
3379- GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
3380- GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d )};
3384+ float32_t _scale [4 ] = {
3385+ GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
3386+ GGML_FP16_TO_FP32 (b_x0 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d ),
3387+ GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y0 -> d ),
3388+ GGML_FP16_TO_FP32 (b_x1 -> d )* GGML_FP16_TO_FP32 (b_y1 -> d )
3389+ };
33813390 float32x4_t scale = vld1q_f32 (_scale );
33823391
33833392 int8x16_t l0 = vreinterpretq_s8_s64 (vzip1q_s64 (vreinterpretq_s64_s8 (x0_l ), vreinterpretq_s64_s8 (x1_l )));
@@ -3393,13 +3402,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
33933402 int8x16_t r3 = vreinterpretq_s8_s64 (vzip2q_s64 (vreinterpretq_s64_s8 (y0_h ), vreinterpretq_s64_s8 (y1_h )));
33943403
33953404 sumv0 = vmlaq_f32 (sumv0 ,(vcvtq_f32_s32 (vmmlaq_s32 ((vmmlaq_s32 ((vmmlaq_s32 ((vmmlaq_s32 (vdupq_n_s32 (0 ), l0 , r0 )),
3396- l1 , r1 )), l2 , r2 )), l3 , r3 ))), scale );
3405+ l1 , r1 )), l2 , r2 )), l3 , r3 ))), scale );
33973406 }
3398- float32x4_t sumv1 = vextq_f32 (sumv0 , sumv0 , 2 );
3407+
3408+ float32x4_t sumv1 = vextq_f32 (sumv0 , sumv0 , 2 );
33993409 float32x4_t sumv2 = vzip1q_f32 (sumv0 , sumv1 );
34003410
3401- vst1_f32 (s , vget_low_f32 (sumv2 ));
3411+ vst1_f32 (s , vget_low_f32 (sumv2 ));
34023412 vst1_f32 (s + bs , vget_high_f32 (sumv2 ));
3413+
34033414 return ;
34043415 }
34053416#endif
0 commit comments