@@ -1585,6 +1585,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+template <bool is_abs = false>
 struct Trellis3 {
     constexpr static uint32_t ka = 0xCBAC1FED;
     constexpr static uint32_t ka1 = ka*ka;
@@ -1611,6 +1612,9 @@ struct Trellis3 {
             i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
             auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
             result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vabsq_s8(result.val[i]);
+            }
         }
         return result;
     }
@@ -1630,6 +1634,9 @@ struct Trellis3 {
             i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
             auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
             result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vreinterpretq_s8_u8(vabsq_s8(result.val[i]));
+            }
         }
         return result;
     }
@@ -1657,6 +1664,9 @@ struct Trellis3 {
             result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1));
             result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2));
         }
+        if constexpr (is_abs) {
+            for (int i = 0; i < 4; ++i) result.val[i] = vabsq_s8(result.val[i]);
+        }
         return result;
     }
     static uint8x16_t load_shuffle() {
@@ -1872,6 +1882,69 @@ void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
     }
 }
 
+void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true> trellis;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_iq3_kt * x8[8];
+
+    float dkt[8];
+    float ls[8], ls_all[64];
+    uint32_t idx[8];
+    uint32_t sign_bits[16];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) {
+            const float * dptr = (const float *)((const char *)vx + (ix+k)*bx);
+            dkt[k] = dptr[0] * 1.05f;
+            x8[k] = (const block_iq3_kt *)(dptr + 1);
+        }
+        auto vd = vld1q_f32_x2(dkt);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto u32 = *(const uint32_t *)x8[k][i].scales;
+                auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+                s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
+                auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
+                vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+                vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            }
+            auto mask = vdupq_n_u8(1);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+                auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
+                auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
+                vst1_f16((float16_t *)y[ib].d + 0, vcvt_f16_f32(scales1));
+                vst1_f16((float16_t *)y[ib].d + 4, vcvt_f16_f32(scales2));
+                for (int j = 0; j < 4; ++j) {
+                    for (int k = 0; k < 8; ++k) {
+                        const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+                        idx[k] = ql[4*ib+j] + 4096;
+                        auto qh = (const uint32_t *)x8[k][i].qh;
+                        sign_bits[k+0] = qh[2*j+0];
+                        sign_bits[k+8] = qh[2*j+1];
+                    }
+                    auto packed = trellis.next64(idx);
+                    auto signs = vld1q_u8_x4((const uint8_t *)sign_bits);
+                    for (int l = 0; l < 4; ++l) {
+                        auto s = vorrq_u8(vceqq_u8(vandq_u8(signs.val[l], mask), mask), vdupq_n_u8(1));
+                        packed.val[l] = vmulq_s8(packed.val[l], vreinterpretq_s8_u8(s));
+                    }
+                    vst1q_s8_x4(y[ib].qs + 64*j, packed);
+                }
+                mask = vshlq_n_u8(mask, 1);
+            }
+            y += 8; // = QK_K/32;
+        }
+    }
+}
+
 template <int nrc_y>
 void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
@@ -1974,6 +2047,126 @@ void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
     }
 }
 
+template <int nrc_y>
+void mul_mat_iq3_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true> trellis;
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    float32x4_t accd[k_acc];
+
+    const block_q8_0_x4 * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+    }
+
+    int8x16x2_t xv[8];
+    int32x4x4_t dot;
+
+    auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+        for (int k = 0; k < 4; ++k) {
+            auto yv = vld1q_s8_x2(y + 32*k);
+            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+        }
+        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+        return vpaddq_s32(dot.val[0], dot.val[2]);
+    };
+
+    float32x4x2_t scales;
+    auto mask = vdupq_n_u8(1);
+    auto maskh = vdupq_n_u8(0x10);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char *)vx + ix*bx);
+        auto d = vdupq_n_f32(dptr[0]*1.05f);
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+        for (int i = 0; i < nb; ++i) {
+            auto u32 = *(const uint32_t *)x[i].scales;
+            auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+            s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
+            auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
+            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            auto sign_bits = vld1q_u8_x2(x[i].qh);
+            if constexpr (nrc_y == 1) {
+                const block_q8_0_x4& ybl = y[0][2*i+0];
+                const block_q8_0_x4& ybh = y[0][2*i+1];
+                auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                int32x4x4_t suml = {};
+                int32x4x4_t sumh = {};
+                for (int ib = 0; ib < 4; ++ib) {
+                    auto xl = trellis.next32(ql + 4*ib + 0, 4096);
+                    auto signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], mask), mask), mask);
+                    auto signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], mask), mask), mask);
+                    xl.val[0] = vmulq_s8(xl.val[0], vreinterpretq_s8_u8(signs1));
+                    xl.val[1] = vmulq_s8(xl.val[1], vreinterpretq_s8_u8(signs2));
+                    auto xh = trellis.next32(ql + 4*ib + 16, 4096);
+                    signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], maskh), maskh), mask);
+                    signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], maskh), maskh), mask);
+                    xh.val[0] = vmulq_s8(xh.val[0], vreinterpretq_s8_u8(signs1));
+                    xh.val[1] = vmulq_s8(xh.val[1], vreinterpretq_s8_u8(signs2));
+                    auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
+                    auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
+                    suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+                    sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+                    sign_bits.val[0] = vshrq_n_u8(sign_bits.val[0], 1);
+                    sign_bits.val[1] = vshrq_n_u8(sign_bits.val[1], 1);
+                }
+                auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
+                auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
+                auto sl = vpaddq_s32(sl1, sl2);
+                auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
+                auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
+                auto sh = vpaddq_s32(sh1, sh2);
+                accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
+                accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
+            } else {
+                for (int k = 0; k < 8; ++k) {
+                    xv[k] = trellis.next32(ql + 4*k, 4096);
+                    auto signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], mask), mask), mask);
+                    auto signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], mask), mask), mask);
+                    xv[k].val[0] = vmulq_s8(xv[k].val[0], vreinterpretq_s8_u8(signs1));
+                    xv[k].val[1] = vmulq_s8(xv[k].val[1], vreinterpretq_s8_u8(signs2));
+                    sign_bits.val[0] = vshrq_n_u8(sign_bits.val[0], 1);
+                    sign_bits.val[1] = vshrq_n_u8(sign_bits.val[1], 1);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_0_x4& ybl = y[iy][2*i+0];
+                    const block_q8_0_x4& ybh = y[iy][2*i+1];
+                    auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                    auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                    auto sumil = compute_dot(ybl.qs, xv+0);
+                    auto sumih = compute_dot(ybh.qs, xv+4);
+                    if constexpr (nrc_y == 1) {
+                        accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+                        accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+                    } else {
+                        accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+                        accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+                    }
+                }
+            }
+        }
+
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, vaddvq_f32(accd[iy]));
+            }
+        }
+    }
+}
+
 }
 
 bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1990,6 +2183,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
         return false;
     }
 
+    if (ggml_type(typeA) == GGML_TYPE_IQ3_KT) {
+        if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_0_x4_T, kernels);
+            func16 = nullptr;
+            return true;
+        }
+        return false;
+    }
+
     if (ggml_type(typeA) == GGML_TYPE_IQ2_KT) {
         if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_0_x4_T, kernels);
@@ -2022,10 +2224,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
     return true;
 }
 
-bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
+bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
     switch (type) {
         case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
-        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
+        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         default: return false;
     }
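
Reviewer note, not part of the patch: the IQ3_KT kernels above store only magnitudes in the trellis path (Trellis3<true> applies vabsq_s8) and reapply the sign afterwards by multiplying with a vector of +1/-1 bytes built from the qh bits. The following standalone sketch isolates that one step; the helper name apply_iq3_kt_signs is made up for illustration and does not exist in the source.

#include <arm_neon.h>
#include <cstdint>

// Illustrative sketch: combine absolute trellis values with iq3_kt sign bits,
// mirroring the vceqq/vorrq/vmulq sequence used in the kernels above.
// 'abs_vals'  : |value| bytes produced by Trellis3<true>
// 'sign_bits' : the corresponding qh bytes
// 'bit'       : which of the eight 32-value sub-blocks is being decoded
static inline int8x16_t apply_iq3_kt_signs(int8x16_t abs_vals, uint8x16_t sign_bits, int bit) {
    const uint8x16_t mask  = vdupq_n_u8(uint8_t(1) << bit);
    // 0xFF (i.e. -1 as int8) where the sign bit is set, 0x01 where it is clear.
    const uint8x16_t signs = vorrq_u8(vceqq_u8(vandq_u8(sign_bits, mask), mask), vdupq_n_u8(1));
    return vmulq_s8(abs_vals, vreinterpretq_s8_u8(signs));
}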