@@ -98,7 +98,7 @@ struct Trellis2 {
 };
 
 
-template <bool is_8 = false>
+template <bool is_8 = false, bool is_abs = false>
 struct Trellis3 {
     constexpr static uint32_t ka = 0xCBAC1FED;
     constexpr static uint32_t ka1 = ka*ka;
@@ -127,7 +127,11 @@ struct Trellis3 {
         auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101));
         auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
 #endif
-        return _mm256_cvtepi32_ps(i8);
+        if constexpr (is_abs) {
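+            // _mm256_sign_epi32(i8, i8) negates the negative lanes, i.e. it yields |i8| per 32-bit value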
+            return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8));
+        } else {
+            return _mm256_cvtepi32_ps(i8);
+        }
     }
     inline __m256 gen8(uint32_t val) const {
         auto v8 = _mm256_and_si256(next8(val), _mm256_set1_epi32(0x3f3f3f3f));
@@ -137,11 +141,14 @@ struct Trellis3 {
         auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101));
         auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
 #endif
-        return _mm256_cvtepi32_ps(i8);
+        if constexpr (is_abs) {
+            return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8));
+        } else {
+            return _mm256_cvtepi32_ps(i8);
+        }
     }
-    template <bool is_unsigned = false>
     inline __m256i next32(const uint32_t * val) const {
-        const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126);
+        const __m256i offset = _mm256_set1_epi32(-126);
         __m256i aux[4];
         for (int i = 0; i < 4; ++i) {
             auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f));
@@ -156,11 +163,15 @@ struct Trellis3 {
         aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
         aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
                                                      // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-        return _mm256_permutevar8x32_epi32(aux[0], shuffle);
+        if constexpr (is_abs) {
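+            // _mm256_sign_epi8(result, result) takes the per-byte absolute value; callers such as iq3_kt re-apply the signs from their qh bits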
+            auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle);
+            return _mm256_sign_epi8(result, result);
+        } else {
+            return _mm256_permutevar8x32_epi32(aux[0], shuffle);
+        }
     }
-    template <bool is_unsigned = false>
     inline __m256i next32(const uint16_t * val, uint32_t v0) const {
-        const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126);
+        const __m256i offset = _mm256_set1_epi32(-126);
         __m256i aux[4];
         for (int i = 0; i < 4; ++i) {
             auto i8 = _mm256_and_si256(next8(v0 + val[i]), _mm256_set1_epi32(0x3f3f3f3f));
@@ -175,11 +186,15 @@ struct Trellis3 {
         aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
         aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
                                                      // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-        return _mm256_permutevar8x32_epi32(aux[0], shuffle);
+        if constexpr (is_abs) {
+            auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle);
+            return _mm256_sign_epi8(result, result);
+        } else {
+            return _mm256_permutevar8x32_epi32(aux[0], shuffle);
+        }
     }
-    template <bool is_unsigned = false>
     inline void next64(const uint32_t * val, __m256i * result) const {
-        const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126);
+        const __m256i offset = _mm256_set1_epi32(-126);
         auto vka3 = _mm256_set1_epi32(ka3);
         __m256i aux[8];
         for (int i = 0; i < 4; ++i) {
@@ -203,6 +218,9 @@ struct Trellis3 {
             aux[4*k+0] = _mm256_packs_epi16(aux[4*k+0], aux[4*k+2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
                                                                      // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
             result[k] = _mm256_permutevar8x32_epi32(aux[4*k+0], shuffle);
+            if constexpr (is_abs) {
+                result[k] = _mm256_sign_epi8(result[k], result[k]);
+            }
         }
     }
 };
@@ -449,6 +467,70 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
     }
 }
 
+void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<false, true> trellis;
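+    // second template argument enables is_abs: next64() returns magnitudes only,
+    // and the signs are re-applied below from the qh bits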
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_iq3_kt * x8[8];
+    float dkt[8];
+    float ls[8];
+    float ls_all[64];
+    uint32_t idx[8];
+    uint32_t sign_bits[16];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) {
+            const float * dptr = (const float *)((const char *)vx + (ix+k)*bx);
+            dkt[k] = dptr[0];
+            x8[k] = (const block_iq3_kt *)(dptr + 1);
+        }
+        auto vd = _mm256_mul_ps(_mm256_set1_ps(1.01f), _mm256_loadu_ps(dkt));
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
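+                // broadcast the 4 packed scale bytes; lane 0 keeps the low nibbles, lane 1
+                // (shifted right by 4) the high ones, yielding eight 4-bit block scales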
+                auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales);
+                s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+                auto s32 = _mm256_cvtepi8_epi32(s8);
+                _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32));
+            }
+            auto mask = _mm256_set1_epi8(1);
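+            // mask picks bit ib of each sign byte; it is shifted left after every 32-value block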
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+                auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls));
+                _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+                for (int j = 0; j < 4; ++j) {
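+                    // gather one trellis index (stored 16-bit value plus the fixed 4096 offset)
+                    // and the two 32-bit sign words per row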
+                    for (int k = 0; k < 8; ++k) {
+                        const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+                        idx[k] = ql[4*ib+j] + 4096;
+                        auto qh = (const uint32_t *)x8[k][i].qh;
+                        sign_bits[k+0] = qh[2*j+0];
+                        sign_bits[k+8] = qh[2*j+1];
+                    }
+                    __m256i packed[2];
+                    trellis.next64(idx, packed);
+                    auto signs1 = _mm256_loadu_si256((const __m256i *)sign_bits+0);
+                    auto signs2 = _mm256_loadu_si256((const __m256i *)sign_bits+1);
+                    signs1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs1, mask), mask), _mm256_set1_epi8(1));
+                    signs2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs2, mask), mask), _mm256_set1_epi8(1));
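+                    // cmpeq yields 0xFF where the sign bit is set; OR with 1 maps {0x00, 0xFF} to {+1, -1},
+                    // so _mm256_sign_epi8 negates exactly the lanes whose sign bit was set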
+                    packed[0] = _mm256_sign_epi8(packed[0], signs1);
+                    packed[1] = _mm256_sign_epi8(packed[1], signs2);
+                    _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]);
+                    _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]);
+                }
+                mask = _mm256_slli_epi16(mask, 1);
+            }
+            y += 8; // = QK_K/32;
+        }
+    }
+}
+
 inline __m256 abs_ps(__m256 vals) {
     // Clear sign-bit of all the 32-bit floats in vals
     __m256 sign_bit = _mm256_set1_ps(-0.0f);
@@ -887,10 +969,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
 
 }
 
-bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
+bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
     switch (type) {
         case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
-        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
+        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         default: return false;
     }