@@ -5509,6 +5509,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
     GGML_UNUSED(by);
 }
 
+//
+// ========================================= q8_KV_r8
+//
+
+void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
+    quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
+    quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
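+// Repack 8 q8_KV rows into one q8_KV_r8 row group: 8 per-row float scales up front,
+// followed by the int8 quants interleaved 4 bytes per row within each block of 16 columns.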
+static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
+    GGML_ASSERT(nrows%8 == 0);
+    GGML_ASSERT(n_per_row%16 == 0);
+    auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+    auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+    const int8_t * x8[8];
+#ifdef __ARM_NEON
+    int8x16x2_t m0, m1, m2, m3;
+#endif
+    for (int row = 0; row < nrows; row += 8) {
+        auto dy = (float *)cy;
+        auto qy = (int8_t *)(dy + 8);
+        for (int k = 0; k < 8; ++k) {
+            auto dx = (const float *)(cx + k*row_size_x);
+            dy[k] = dx[0];
+            x8[k] = (const int8_t *)(dx + 2);
+        }
+        for (int ib = 0; ib < n_per_row/16; ++ib) {
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+            auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+            auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+            auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+            auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+            auto t0 = _mm256_unpacklo_epi32(m0, m1);
+            auto t1 = _mm256_unpacklo_epi32(m2, m3);
+            auto t2 = _mm256_unpackhi_epi32(m0, m1);
+            auto t3 = _mm256_unpackhi_epi32(m2, m3);
+            m0 = _mm256_unpacklo_epi64(t0, t1);
+            m1 = _mm256_unpackhi_epi64(t0, t1);
+            m2 = _mm256_unpacklo_epi64(t2, t3);
+            m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
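+            // run-time repacking: shift the quants to unsigned (q + 127);
+            // offline-repacked tensors get the same shift later via modify_q8_KV_r8()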
+            if (online) {
+                m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+                m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+                m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+                m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+            }
+#endif
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
+            _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
+#elif defined __ARM_NEON
+            m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+            m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+            m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+            m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+            auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+            auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+            m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+            row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+            m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+            m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+            vst1q_s8_x2(qy +  0 + 128*ib, m0);
+            vst1q_s8_x2(qy + 32 + 128*ib, m1);
+            vst1q_s8_x2(qy + 64 + 128*ib, m2);
+            vst1q_s8_x2(qy + 96 + 128*ib, m3);
+#else
+            // scalar fallback: interleave 4 int8 values per row within each block of 16 columns
+            for (int l = 0; l < 4; ++l) {
+                for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+                    qy[128*ib + 32*l + 4*k + i] = x8[k][16*ib + 4*l + i];
+                }
+            }
+#endif
+
+        }
+        cx += 8*row_size_x;
+        cy += online ? 8*row_size_x : 8*row_size_y;
+        // if we are run-time repacking (online = true) we don't want to change the stride,
+        // so we just leave some unused space at the end of each row
+    }
+}
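+// applies the same +127 shift as the online repack path to tensors already stored as Q8_KV_R8
+// (registered in iqk_modify_tensor()'s k_mod_map below)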
+#ifdef HAVE_FANCY_SIMD
+static void modify_q8_KV_r8(int64_t k, char * cy) {
+    int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
+    for (int j = 0; j < k; ++j) q8[j] += 127;
+}
+#endif
+
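+// quantizes 8 rows at a time to Q8_KV into a temporary buffer, then repacks them into the interleaved R8 layout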
+size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
+    GGML_ASSERT(nrows%8 == 0);
+    GGML_ASSERT(n_per_row%16 == 0);
+    char * qcur = (char *)dst;
+    auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+    auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+    std::vector<char> qtmp(8*row_size_0);
+    for (int row = 0; row < nrows; row += 8) {
+        quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
+        repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
+        qcur += 8*row_size_1;
+        src += 8*n_per_row;
+    }
+    return nrows*row_size_1;
+}
+
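+// de-interleaves one row group back into 8 full rows; dptr holds the 8 per-row scales, q8 the interleaved quants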
+void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
+    auto n_per_row = k/8;
+    float * y8[8];
+    for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
+    auto dptr = (const float *)vx;
+    auto q8 = (const int8_t *)(dptr + 8);
+    for (int ib = 0; ib < n_per_row/16; ++ib) {
+        for (int k = 0; k < 8; ++k) {
+            for (int l = 0; l < 4; ++l) {
+                for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
+            }
+        }
+    }
+}
+
+void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+        return;
+    }
+#endif
+    GGML_ASSERT(n%QK4_NL == 0);
+    GGML_ASSERT(nrc == 1);
+    GGML_UNUSED(bs);
+    GGML_UNUSED(bx);
+    GGML_UNUSED(by);
+}
+
 //
 // ========================================= bf16_r4
 //
@@ -6610,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
         { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
 #endif
 #ifdef HAVE_FANCY_SIMD
-        { GGML_TYPE_Q8_0_R8,  {modify_q8_0_r8, 8} },
-        { GGML_TYPE_Q8_K_R8,  {modify_q8_k_r8, 8} },
+        { GGML_TYPE_Q8_0_R8,  {modify_q8_0_r8,  8} },
+        { GGML_TYPE_Q8_K_R8,  {modify_q8_k_r8,  8} },
+        { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
 #endif
     };
     auto it = k_mod_map.find(tensor->type);
@@ -6670,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
         { GGML_TYPE_Q6_0,  { GGML_TYPE_Q6_0_R4,  4,  (Repack::repack_func)repack_q6_0}  },
         { GGML_TYPE_Q8_0,  { GGML_TYPE_Q8_0_R8,  8,  (Repack::repack_func)repack_q8_0}  },
         { GGML_TYPE_Q8_K,  { GGML_TYPE_Q8_K_R8,  8,  (Repack::repack_func)repack_q8_k}  },
+        { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8,  (Repack::repack_func)repack_q8_KV} },
 #ifdef __AVX512BF16__
         { GGML_TYPE_BF16,  { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
         { GGML_TYPE_F16,   { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>}  },