
Commit 1843ed2

ikawrakow and Iwan Kawrakow authored

New integer trellis on ARM_NEON (#544)

* Adapt iq3_kt to new trellis on NEON
* iq3_kt is now working on NEON

Co-authored-by: Iwan Kawrakow <[email protected]>

1 parent 144ee1c

File tree: 2 files changed, +206 −4 lines changed


ggml/src/iqk/iqk_gemm_ktquants.cpp
Lines changed: 204 additions & 2 deletions
@@ -1585,6 +1585,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+template <bool is_abs = false>
 struct Trellis3 {
     constexpr static uint32_t ka = 0xCBAC1FED;
     constexpr static uint32_t ka1 = ka*ka;
@@ -1611,6 +1612,9 @@ struct Trellis3 {
             i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
             auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
             result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vabsq_s8(result.val[i]);
+            }
         }
         return result;
     }
@@ -1630,6 +1634,9 @@ struct Trellis3 {
             i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
             auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
             result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
+            if constexpr (is_abs) {
+                result.val[i] = vreinterpretq_s8_u8(vabsq_s8(result.val[i]));
+            }
         }
         return result;
     }
@@ -1657,6 +1664,9 @@ struct Trellis3 {
             result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1));
             result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2));
         }
+        if constexpr (is_abs) {
+            for (int i = 0; i < 4; ++i) result.val[i] = vabsq_s8(result.val[i]);
+        }
        return result;
    }
    static uint8x16_t load_shuffle() {
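
As a reading aid: a scalar sketch of what one Trellis3 sample works out to, pieced together from the ka multiplier, the 0x3f3f3f3f masks, and the vpaddq_s8 byte sums in the hunks above. The -126 bias is not visible in this diff and is an assumption; trellis3_value is a hypothetical helper name, not part of the commit.

// Hedged scalar sketch, NOT the committed code. Assumption: each int8 sample
// is the sum of 8 six-bit bytes (two 32-bit PRNG steps, each byte & 0x3f)
// plus a -126 bias; the is_abs variant keeps only the magnitude, since
// iq3_kt stores the signs separately in qh.
#include <cstdint>
#include <initializer_list>

int8_t trellis3_value(uint32_t idx, bool is_abs) {
    constexpr uint32_t ka = 0xCBAC1FED;  // same multiplier as in the struct
    uint32_t x1 = ka * idx;              // first step
    uint32_t x2 = ka * x1;               // second step (ka1 = ka*ka applied to idx)
    int sum = -126;                      // assumed bias centering the byte sum
    for (uint32_t x : {x1, x2})
        for (int b = 0; b < 4; ++b) sum += (x >> 8*b) & 0x3f;  // 0x3f3f3f3f mask
    if (is_abs && sum < 0) sum = -sum;
    return (int8_t)sum;
}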
@@ -1872,6 +1882,69 @@ void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
     }
 }
 
+void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true> trellis;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_iq3_kt * x8[8];
+
+    float dkt[8];
+    float ls[8], ls_all[64];
+    uint32_t idx[8];
+    uint32_t sign_bits[16];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) {
+            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
+            dkt[k] = dptr[0] * 1.05f;
+            x8[k] = (const block_iq3_kt *)(dptr + 1);
+        }
+        auto vd = vld1q_f32_x2(dkt);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto u32 = *(const uint32_t *)x8[k][i].scales;
+                auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+                s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
+                auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
+                vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
+                vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            }
+            auto mask = vdupq_n_u8(1);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
+                auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
+                auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
+                vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
+                vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
+                for (int j = 0; j < 4; ++j) {
+                    for (int k = 0; k < 8; ++k) {
+                        const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
+                        idx[k] = ql[4*ib+j] + 4096;
+                        auto qh = (const uint32_t *)x8[k][i].qh;
+                        sign_bits[k+0] = qh[2*j+0];
+                        sign_bits[k+8] = qh[2*j+1];
+                    }
+                    auto packed = trellis.next64(idx);
+                    auto signs = vld1q_u8_x4((const uint8_t *)sign_bits);
+                    for (int l = 0; l < 4; ++l) {
+                        auto s = vorrq_u8(vceqq_u8(vandq_u8(signs.val[l], mask), mask), vdupq_n_u8(1));
+                        packed.val[l] = vmulq_s8(packed.val[l], vreinterpretq_s8_u8(s));
+                    }
+                    vst1q_s8_x4(y[ib].qs+64*j, packed);
+                }
+                mask = vshlq_n_u8(mask, 1);
+            }
+            y += 8; // = QK_K/32;
+        }
+    }
+}
+
 template <int nrc_y>
 void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
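
The sign application in the dequantizer above is branch-free: vceqq_u8 maps a set sign bit to 0xFF and a clear bit to 0x00, and OR-ing in 1 turns that into an int8 multiplier of -1 or +1 for vmulq_s8. Per byte lane it is equivalent to this scalar helper (the helper name is ours):

// Scalar equivalent of one lane of:
//   vorrq_u8(vceqq_u8(vandq_u8(signs, mask), mask), vdupq_n_u8(1))
// where mask holds the per-32-block bit (1, 2, 4, ..., advanced by vshlq_n_u8).
#include <cstdint>

int8_t sign_multiplier(uint8_t sign_byte, uint8_t mask) {
    uint8_t eq = ((sign_byte & mask) == mask) ? 0xFF : 0x00;  // vceqq_u8
    return (int8_t)(eq | 0x01);                               // 0xFF -> -1, 0x01 -> +1
}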
@@ -1974,6 +2047,126 @@ void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
     }
 }
 
+template <int nrc_y>
+void mul_mat_iq3_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis3<true> trellis;
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    float32x4_t accd[k_acc];
+
+    const block_q8_0_x4 * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
+    }
+
+    int8x16x2_t xv[8];
+    int32x4x4_t dot;
+
+    auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
+        for (int k = 0; k < 4; ++k) {
+            auto yv = vld1q_s8_x2(y + 32*k);
+            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
+        }
+        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
+        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
+        return vpaddq_s32(dot.val[0], dot.val[2]);
+    };
+
+    float32x4x2_t scales;
+    auto mask = vdupq_n_u8(1);
+    auto maskh = vdupq_n_u8(0x10);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = vdupq_n_f32(dptr[0]*1.05f);
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
+
+        for (int i = 0; i < nb; ++i) {
+            auto u32 = *(const uint32_t *)x[i].scales;
+            auto s8_u32 = uint32x2_t{u32, u32 >> 4};
+            s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
+            auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
+            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
+            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            auto sign_bits = vld1q_u8_x2(x[i].qh);
+            if constexpr (nrc_y == 1) {
+                const block_q8_0_x4& ybl = y[0][2*i+0];
+                const block_q8_0_x4& ybh = y[0][2*i+1];
+                auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                int32x4x4_t suml = {};
+                int32x4x4_t sumh = {};
+                for (int ib = 0; ib < 4; ++ib) {
+                    auto xl = trellis.next32(ql + 4*ib + 0, 4096);
+                    auto signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], mask), mask), mask);
+                    auto signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], mask), mask), mask);
+                    xl.val[0] = vmulq_s8(xl.val[0], vreinterpretq_s8_u8(signs1));
+                    xl.val[1] = vmulq_s8(xl.val[1], vreinterpretq_s8_u8(signs2));
+                    auto xh = trellis.next32(ql + 4*ib + 16, 4096);
+                    signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], maskh), maskh), mask);
+                    signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], maskh), maskh), mask);
+                    xh.val[0] = vmulq_s8(xh.val[0], vreinterpretq_s8_u8(signs1));
+                    xh.val[1] = vmulq_s8(xh.val[1], vreinterpretq_s8_u8(signs2));
+                    auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
+                    auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
+                    suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
+                    sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
+                    sign_bits.val[0] = vshrq_n_u8(sign_bits.val[0], 1);
+                    sign_bits.val[1] = vshrq_n_u8(sign_bits.val[1], 1);
+                }
+                auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
+                auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
+                auto sl = vpaddq_s32(sl1, sl2);
+                auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
+                auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
+                auto sh = vpaddq_s32(sh1, sh2);
+                accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
+                accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
+            } else {
+                for (int k = 0; k < 8; ++k) {
+                    xv[k] = trellis.next32(ql + 4*k, 4096);
+                    auto signs1 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[0], mask), mask), mask);
+                    auto signs2 = vorrq_u8(vceqq_u8(vandq_u8(sign_bits.val[1], mask), mask), mask);
+                    xv[k].val[0] = vmulq_s8(xv[k].val[0], vreinterpretq_s8_u8(signs1));
+                    xv[k].val[1] = vmulq_s8(xv[k].val[1], vreinterpretq_s8_u8(signs2));
+                    sign_bits.val[0] = vshrq_n_u8(sign_bits.val[0], 1);
+                    sign_bits.val[1] = vshrq_n_u8(sign_bits.val[1], 1);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_0_x4& ybl = y[iy][2*i+0];
+                    const block_q8_0_x4& ybh = y[iy][2*i+1];
+                    auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
+                    auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
+                    auto sumil = compute_dot(ybl.qs, xv+0);
+                    auto sumih = compute_dot(ybh.qs, xv+4);
+                    if constexpr (nrc_y == 1) {
+                        accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
+                        accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
+                    } else {
+                        accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
+                        accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
+                    }
+                }
+            }
+        }
+
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, vaddvq_f32(accd[iy]));
+            }
+        }
+    }
+}
+
 }
 
 bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
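
In the GEMM above, the compute_dot lambda folds 128 int8 products into one int32 per 32-value block: each vdotq_s32 accumulates 4-element dot products into four lanes, and the vpaddq_s32 chain pairs lanes until lane b holds the complete dot product of block b. A scalar restatement (the helper name is ours, not the commit's):

// Scalar restatement of the compute_dot lambda: out[b] is the dot product of
// the b-th group of 32 int8 values, which the NEON code reaches via
// vdotq_s32 accumulation followed by the vpaddq_s32 reduction chain.
#include <cstdint>

void compute_dot_scalar(const int8_t * x, const int8_t * y, int32_t out[4]) {
    for (int b = 0; b < 4; ++b) {
        int32_t s = 0;
        for (int j = 0; j < 32; ++j) s += (int32_t)x[32*b + j] * (int32_t)y[32*b + j];
        out[b] = s;
    }
}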
@@ -1990,6 +2183,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
         return false;
     }
 
+    if (ggml_type(typeA) == GGML_TYPE_IQ3_KT) {
+        if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_0_x4_T, kernels);
+            func16 = nullptr;
+            return true;
+        }
+        return false;
+    }
+
     if (ggml_type(typeA) == GGML_TYPE_IQ2_KT) {
         if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_0_x4_T, kernels);
@@ -2022,10 +2224,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
         return true;
     }
 
-bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
+bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
     switch (type) {
         case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
-        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
+        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
         default: return false;
     }

ggml/src/iqk/iqk_mul_mat.cpp
Lines changed: 2 additions & 2 deletions
@@ -272,7 +272,7 @@ struct MulMat {
 #else
         switch (type) {
             case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-            case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
+            case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
             case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
             default: break;
         }
@@ -435,7 +435,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
             return iqk_convert_1bit_q80_r8(typeA, n, vx, bx, vy, nrc_x);
 
         default:
-            return false;
+            break;
     }
 
     return false;
