Skip to content

Commit b407232

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents 378986d + 38012f7 commit b407232

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

ggml/src/iqk/iqk_gemm_1bit.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,68 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
865865
}
866866
}
867867

868+
template <int nrc_y>
869+
void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
870+
GGML_ASSERT(n%QK_K == 0);
871+
Q8<nrc_y, block_q8_K> q8(info);
872+
__m256i qx[8];
873+
__m256 acc[nrc_y] = {};
874+
auto scale_shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0504050405040504, 0x0302030203020302, 0x0100010001000100);
875+
auto delta_mask = _mm256_set_epi64x(0x8000, 0x0800, 0x0080, 0x0008);
876+
iq1m_scale_t scale;
877+
union { __m256i vec; int16_t val[16]; } helper;
878+
for (int ix = 0; ix < nrc_x; ++ix) {
879+
auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx);
880+
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
881+
const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales
882+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
883+
float d = GGML_FP16_TO_FP32(scale.f16);
884+
auto qs = iq1m[ibl].qs;
885+
auto qh = iq1m[ibl].qh;
886+
auto aux = _mm_loadl_epi64((const __m128i *)iq1m[ibl].scales);
887+
auto sc16 = _mm256_shuffle_epi8(MM256_SET_M128I(aux, aux), scale_shuffle);
888+
sc16 = _mm256_and_si256(sc16, _mm256_set1_epi64x(0x0e0001c000380007));
889+
sc16 = _mm256_mullo_epi16(sc16, _mm256_set1_epi64x(0x0001000800400200));
890+
helper.vec = _mm256_add_epi8(_mm256_srli_epi16(sc16, 8), _mm256_set1_epi16(1));
891+
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
892+
qx[2*ib64+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
893+
iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
894+
qx[2*ib64+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
895+
iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
896+
//auto delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
897+
// qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
898+
// qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
899+
// qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
900+
//auto delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
901+
// qh[3] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
902+
// qh[2] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
903+
// qh[2] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
904+
auto qh16 = (const uint16_t *)qh;
905+
auto delta1 = _mm256_cmpeq_epi64(_mm256_and_si256(_mm256_set1_epi64x(qh16[0]), delta_mask), delta_mask);
906+
auto delta2 = _mm256_cmpeq_epi64(_mm256_and_si256(_mm256_set1_epi64x(qh16[1]), delta_mask), delta_mask);
907+
delta1 = _mm256_sub_epi8(_mm256_set1_epi8(8), _mm256_or_si256(delta1, _mm256_set1_epi8(1)));
908+
delta2 = _mm256_sub_epi8(_mm256_set1_epi8(8), _mm256_or_si256(delta2, _mm256_set1_epi8(1)));
909+
qx[2*ib64+0] = _mm256_sub_epi8(_mm256_slli_epi16(qx[2*ib64+0], 3), delta1);
910+
qx[2*ib64+1] = _mm256_sub_epi8(_mm256_slli_epi16(qx[2*ib64+1], 3), delta2);
911+
qs += 8;
912+
qh += 4;
913+
}
914+
for (int iy = 0; iy < nrc_y; ++iy) {
915+
auto sumi = _mm256_setzero_si256();
916+
for (int j = 0; j < 8; ++j) {
917+
auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx[j], qx[j]), _mm256_sign_epi8(q8.load_quants(iy, ibl, j), qx[j]));
918+
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(p, MM256_SET_M128I(_mm_set1_epi16(helper.val[2*j+1]), _mm_set1_epi16(helper.val[2*j+0]))));
919+
}
920+
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
921+
}
922+
}
923+
for (int iy = 0; iy < nrc_y; ++iy) {
924+
info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
925+
acc[iy] = _mm256_setzero_ps();
926+
}
927+
}
928+
}
929+
868930
template <int nrc_y>
869931
void mul_mat_iq1_s_q8_2_x4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
870932
GGML_ASSERT(n%QK_K == 0);
@@ -1844,6 +1906,11 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
18441906
func16 = mul_mat_iq1_s_r4_q8_1<16>;
18451907
#endif
18461908
break;
1909+
case GGML_TYPE_IQ1_M:
1910+
if (ne00%QK_K != 0) return false;
1911+
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs);
1912+
expected_typeB = GGML_TYPE_Q8_K;
1913+
break;
18471914
case GGML_TYPE_IQ1_M_R4:
18481915
if (ne00%128 != 0) return false;
18491916
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
867867
case GGML_TYPE_IQ4_NL_R4:
868868
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
869869
case GGML_TYPE_IQ1_S:
870+
case GGML_TYPE_IQ1_M:
870871
case GGML_TYPE_IQ1_S_R4:
871872
case GGML_TYPE_IQ1_M_R4:
872873
case GGML_TYPE_IQ1_BN:
@@ -958,6 +959,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
958959
case GGML_TYPE_IQ2_BN:
959960
case GGML_TYPE_IQ2_BN_R4:
960961
case GGML_TYPE_IQ1_S:
962+
case GGML_TYPE_IQ1_M:
961963
case GGML_TYPE_IQ1_S_R4:
962964
case GGML_TYPE_IQ1_M_R4:
963965
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);

0 commit comments

Comments
 (0)