@@ -15514,8 +15514,13 @@ struct HelperQ8KVR8 : public BaseHelper<step> {
1551415514template <int D, int step>
1551515515struct HelperQ40 final : public BaseHelper<step> {
1551615516 using Base = BaseHelper<step>;
15517+ #if defined __AVX2__
1551715518 using block_q8 = block_q8_2;
1551815519 constexpr static int block_size_q = QK8_2;
15520+ #else
15521+ using block_q8 = block_q8_0;
15522+ constexpr static int block_size_q = QK8_0;
15523+ #endif
1551915524 HelperQ40(const char * data, int stride) : Base(data, stride) {}
1552015525
1552115526 // Needed for v * softmax(k * q)
@@ -15558,8 +15563,8 @@ struct HelperQ40 final : public BaseHelper<step> {
1555815563template <int D, int step>
1555915564struct HelperQ41 final : public BaseHelper<step> {
1556015565 using Base = BaseHelper<step>;
15561- using block_q8 = block_q8_1 ;
15562- constexpr static int block_size_q = QK8_1 ;
15566+ using block_q8 = block_q8_2 ;
15567+ constexpr static int block_size_q = QK8_2 ;
1556315568 HelperQ41(const char * data, int stride) : Base(data, stride) {}
1556415569
1556515570 // Needed for v * softmax(k * q)
@@ -16414,7 +16419,7 @@ struct FlashQKfp32 {
1641416419#ifdef __aarch64__
1641516420 MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
1641616421#else
16417- MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_Unpacker , nq);
16422+ MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker , nq);
1641816423#endif
1641916424 }
1642016425 else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
0 commit comments