@@ -188,20 +188,14 @@ struct ScaleHelperQ8_2 {
188188 inline __m256 prepare4 (__m256 other_scales, const Q * y) {
189189 return _mm256_mul_ps (other_scales, prepare4<Q>(y));
190190 }
191- template <typename Q> inline std::pair<float , float > prepare1 (const Q * y) const {
192- float d = GGML_BF16_TO_FP32 (y->d );
191+ template <typename Q> static inline std::pair<float , float > prepare1 (const Q * y) {
192+ float d = GGML_BF16_TO_FP32 (ggml_bf16_t { y->d } );
193193 int16_t m = *(const int16_t *)&y->s ;
194194 return std::make_pair (d, d*m);
195195 }
196- template <typename Q> inline std::pair<float , float > prepare1 (const std::pair<float , float >& dm, const Q * y) const {
197- float d = GGML_BF16_TO_FP32 (y->d );
198- int16_t m = *(const int16_t *)&y->s ;
199- return std::make_pair (dm.first *d, dm.second *d*m);
200- }
201- std::pair<float , float > inline prepare1 (const std::pair<float , float >& dm, const block_q8_2 * y) const {
202- ggml_bf16_t dy; dy.bits = y->d ; int16_t s = *(const int16_t *)&y->s ;
203- float d = GGML_BF16_TO_FP32 (dy);
204- return std::make_pair (dm.first *d, dm.second *d*s);
196+ static inline std::pair<float , float > prepare1 (const std::pair<float , float >& dm, const block_q8_2 * y) {
197+ auto d = prepare1 (y);
198+ return std::make_pair (dm.first *d.first , dm.second *d.second );
205199 }
206200};
207201
@@ -1484,14 +1478,14 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
14841478 }
14851479 }
14861480 if (4 *(nb/4 ) < nb) {
1487- auto qy = (const block_q8_1 *)q8.y [0 ];
1481+ auto qy = (const block_q8_2 *)q8.y [0 ];
14881482 for (int ib = 4 *(nb/4 ); ib < nb; ++ib) {
14891483 auto scales = _mm256_cvtph_ps (_mm_loadu_si128 ((const __m128i *)iq8[ib].d ));
14901484 auto sumi = q8_0_r8_dot_product ((const uint8_t *)iq8[ib].qs , qy[ib].qs , qx);
1491- ggml_bf16_t d, s; d. bits = qy[ib]. d ; s. bits = qy[ib]. s ;
1492- auto d4d8 = _mm256_mul_ps (scales, _mm256_set1_ps (GGML_BF16_TO_FP32 (d) ));
1485+ auto [d8, m8] = ScaleHelperQ8_2::prepare1 (qy + ib) ;
1486+ auto d4d8 = _mm256_mul_ps (scales, _mm256_set1_ps (d8 ));
14931487 acc[0 ] = _mm256_fmadd_ps (d4d8, _mm256_cvtepi32_ps (sumi), acc[0 ]);
1494- acc[1 ] = _mm256_fmadd_ps (scales, _mm256_set1_ps (GGML_BF16_TO_FP32 (s) ), acc[1 ]);
1488+ acc[1 ] = _mm256_fmadd_ps (scales, _mm256_set1_ps (m8 ), acc[1 ]);
14951489 }
14961490 }
14971491 info.store (ix, 0 , _mm256_fmadd_ps (_mm256_set1_ps (-127 .f ), acc[1 ], acc[0 ]));
@@ -1535,12 +1529,12 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
15351529 qx[j] = _mm512_add_epi8 (qx[j], _mm512_set1_epi8 (127 ));
15361530 }
15371531 for (int iy = 0 ; iy < nrc_y; ++iy) {
1538- auto qy = (const block_q8_1 *)q8.y [iy];
1532+ auto qy = (const block_q8_2 *)q8.y [iy];
15391533 auto sumi = qx_r8_q8_dot_product (qx, qy[ib].qs );
1540- ggml_bf16_t d, s; d. bits = qy[ib]. d ; s. bits = qy[ib]. s ;
1541- auto dy = _mm512_set1_ps (GGML_BF16_TO_FP32 (d) );
1534+ auto [d8, m8] = ScaleHelperQ8_2::prepare1 (qy + ib) ;
1535+ auto dy = _mm512_set1_ps (d8 );
15421536 acc[2 *iy+0 ] = _mm512_fmadd_ps (_mm512_mul_ps (scales, dy), _mm512_cvtepi32_ps (sumi), acc[2 *iy+0 ]);
1543- acc[2 *iy+1 ] = _mm512_fmadd_ps (scales, _mm512_set1_ps (GGML_BF16_TO_FP32 (s) ), acc[2 *iy+1 ]);
1537+ acc[2 *iy+1 ] = _mm512_fmadd_ps (scales, _mm512_set1_ps (m8 ), acc[2 *iy+1 ]);
15441538 }
15451539 }
15461540 for (int iy = 0 ; iy < nrc_y; ++iy) {
0 commit comments