@@ -3528,26 +3528,27 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D
 template <int nrc_y>
 static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_1_x4> q8(info);
+    Q8<nrc_y, block_q8_K128> q8(info);
     int nb = n / 32;
     GGML_ASSERT(nb%4 == 0);
     __m256i qx[4];
     __m256 acc[nrc_y] = {};
     auto m1 = _mm256_set1_epi16(1);
     auto ms = _mm_set1_epi16(-32768);
-    float d8[8*nrc_y];
+    float d8[4*nrc_y];
     union { __m256i vec; uint16_t val[16]; } helper;
     struct aux_iq1_s_r4 {
         uint8_t qs[16];
         uint64_t qh;
     };
-    for (int ix= 0; ix < nrc_x; ix += 4) {
+    for (int ix = 0; ix < nrc_x; ix += 4) {
         auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
         auto d1 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr));
         auto x = (const aux_iq1_s_r4 *)(dptr + 4);
         for (int ib = 0; ib < nb/4; ++ib) {
             for (int iy = 0; iy < nrc_y; ++iy) {
-                _mm256_storeu_ps(d8 + 8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib].d)));
+                auto bsums = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].bsums));
+                _mm_storeu_ps(d8 + 4*iy, _mm_mul_ps(_mm_set1_ps(q8.y[iy][ib].d), _mm_cvtepi32_ps(bsums)));
             }
             for (int k = 0; k < 4; ++k) {
                 auto idxh = _mm256_set1_epi64x(x[4*ib+k].qh);
@@ -3556,8 +3557,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
                 scales4 = _mm_or_si128(_mm_slli_epi16(scales4, 1), _mm_set1_epi16(1));
                 auto signs = _mm_or_si128(_mm_cmpeq_epi16(_mm_and_si128(sas, ms), ms), _mm256_castsi256_si128(m1));
                 signs = _mm_add_epi16(_mm_set1_epi16(-8), signs);
-                auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32(
-                                _mm_mullo_epi16(scales4, signs))));
+                signs = _mm_mullo_epi16(signs, scales4);
+                auto delta4 = _mm_mul_ps(_mm_set1_ps(0.0625f), _mm_cvtepi32_ps(_mm_cvtepi16_epi32(signs)));
                 auto delta = _mm256_set_m128(delta4, delta4);
                 scales4 = _mm_unpacklo_epi16(scales4, scales4); // 0,0, 1,1, 2,2, 3,3
                 auto scales = MM256_SET_M128I(scales4, scales4);
@@ -3598,8 +3599,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
                     auto sumi = _mm256_packs_epi32(sumi1, sumi2);
 #endif
                     sumi = _mm256_madd_epi16(scales, sumi);
-                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+0]), _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[8*iy+k+4]), delta, acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), delta, acc[iy]);
                 }
             }
         }
@@ -3614,7 +3615,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
 template <int nrc_y>
 static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_0_x4> q8(info);
+    Q8<nrc_y, block_q8_K128> q8(info);
     int nb = n / 32;
     GGML_ASSERT(nb%4 == 0);
     auto shuffle0 = _mm256_set_epi64x(0x0909090909090909, 0x0808080808080808, 0x0101010101010101, 0x0000000000000000);
@@ -3624,17 +3625,14 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
 #endif
     __m256i qx[4];
     __m256 acc[nrc_y] = {};
+    __m256i isum[nrc_y] = {};
     auto ms = _mm_set1_epi8(0x08);
-    float d8[4*nrc_y];
     union { __m256i vec; uint16_t val[16]; } helper;
     for (int ix= 0; ix < nrc_x; ix += 4) {
         auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
         auto d1 = _mm_mul_ps(_mm_set1_ps(0.125f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)dptr)));
         auto x = (const block_iq1_m_r4 *)(dptr + 4);
         for (int ib = 0; ib < nb/4; ++ib) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                _mm_storeu_ps(d8 + 4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib].d)));
-            }
             for (int k = 0; k < 4; ++k) {
                 auto qh = (const uint32_t *)x[4*ib+k].qh;
                 auto idxh = _mm_set_epi32(qh[1] >> 4, qh[1], qh[0] >> 4, qh[0]);
@@ -3694,10 +3692,13 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
                     // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 as int16_t
                     auto sumi = _mm256_packs_epi32(sumi1, sumi2);
 #endif
-                    sumi = _mm256_madd_epi16(scales, sumi);
-                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d8[4*iy+k]), _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
                 }
             }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.y[iy][ib].d), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+                isum[iy] = _mm256_setzero_si256();
+            }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
             auto sumf = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
@@ -9177,7 +9178,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
 #ifdef HAVE_FANCY_SIMD
             mm.func16 = mul_mat_iq1_s_r4_q8_1<16>;
 #endif
-            expected_typeB = GGML_TYPE_Q8_1_X4;
+            expected_typeB = GGML_TYPE_Q8_K128;
             break;
         case GGML_TYPE_IQ1_M_R4:
             assert (ne00 % QK4_NL == 0);
@@ -9192,7 +9193,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
 #ifdef HAVE_FANCY_SIMD
             mm.func16 = mul_mat_iq1_m_r4_q8_0<16>;
 #endif
-            expected_typeB = GGML_TYPE_Q8_0_X4;
+            expected_typeB = GGML_TYPE_Q8_K128;
             break;
 
         default:
@@ -12072,7 +12073,7 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
 
 static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
-    Q8<1, block_q8_1_x4> q8(info);
+    Q8<1, block_q8_K128> q8(info);
     int nb = n / 32;
     GGML_ASSERT(nb%4 == 0);
     int8x16_t qx[8];
@@ -12084,8 +12085,8 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat
         auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
         auto x = (const block_iq1_s_r4 *)(dptr + 4);
         for (int ib = 0; ib < nb/4; ++ib) {
-            auto scale_yd = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+0));
-            auto scale_ym = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+4));
+            auto scale_yd = vdupq_n_f32(q8.y[0][ib].d);
+            auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums))));
             for (int k = 0; k < 4; ++k) {
                 auto sas = vld1_u16(x[4*ib+k].qh);
                 auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
@@ -12135,23 +12136,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
 template <int nrc_y>
 static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_1_x4> q8(info);
+    Q8<nrc_y, block_q8_K128> q8(info);
     int nb = n / 32;
     GGML_ASSERT(nb%4 == 0);
     uint8x16_t qx[8];
     int32x4_t acc[nrc_y] = {};
     auto ms = vdup_n_u16(0x8000);
     auto mask = vdupq_n_s8(0x03);
-    float d8[8*nrc_y];
+    float d8[4*nrc_y];
     for (int ix= 0; ix < nrc_x; ix += 4) {
         auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
         auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
         auto x = (const block_iq1_s_r4 *)(dptr + 4);
         for (int ib = 0; ib < nb/4; ++ib) {
             for (int iy = 0; iy < nrc_y; ++iy) {
-                auto scales = vld1q_f16((const float16_t *)q8.y[iy][ib].d);
-                vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vget_low_f16(scales)));
-                vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vget_high_f16(scales)));
+                auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums)));
+                vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales));
             }
             for (int k = 0; k < 4; ++k) {
                 auto sas = vld1_u16(x[4*ib+k].qh);
@@ -12193,8 +12193,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
                     sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
                     sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
                     sumi = vmulq_s32(scales, sumi);
-                    acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi));
-                    acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4);
+                    acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi));
+                    acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4);
                 }
             }
         }
@@ -12208,25 +12208,21 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
 template <int nrc_y>
 static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
-    Q8<nrc_y, block_q8_0_x4> q8(info);
+    Q8<nrc_y, block_q8_K128> q8(info);
     int nb = n / 32;
     GGML_ASSERT(nb%4 == 0);
     int8x16_t qx[8];
-    int32x4_t acc[nrc_y] = {};
+    float32x4_t acc[nrc_y] = {};
+    int32x4_t isum[nrc_y] = {};
     auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303};
     auto step = vdupq_n_u8(4);
     auto ms = vdupq_n_u8(0x08);
     auto mask = vdupq_n_s8(0x18);
-    float d8[4*nrc_y];
     for (int ix= 0; ix < nrc_x; ix += 4) {
         auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
         auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr)));
         auto x = (const block_iq1_m_r4 *)(dptr + 4);
         for (int ib = 0; ib < nb/4; ++ib) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto scales = vld1_f16((const float16_t *)q8.y[iy][ib].d);
-                vst1q_f32(d8+4*iy, vcvt_f32_f16(scales));
-            }
             for (int k = 0; k < 4; ++k) {
                 auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]);
                 scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf));
@@ -12272,10 +12268,13 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
                     sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
                     sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
                     sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
-                    auto sumi = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), sumi1, scales1), sumi2, scales2);
-                    acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), vcvtq_f32_s32(sumi));
+                    isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2);
                 }
             }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy]));
+                isum[iy] = vdupq_n_s32(0);
+            }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
             info.store(ix, iy, vmulq_f32(d1, acc[iy]));
@@ -13907,12 +13906,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1);
             m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
             m.func16 = mul_mat_iq1_s_r4_q8_1<16>;
-            expected_Btype = GGML_TYPE_Q8_1_X4;
+            expected_Btype = GGML_TYPE_Q8_K128;
             break;
         case GGML_TYPE_IQ1_M_R4:
             SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0);
             m.func16 = mul_mat_iq1_m_r4_q8_0<16>;
-            expected_Btype = GGML_TYPE_Q8_0_X4;
+            expected_Btype = GGML_TYPE_Q8_K128;
             break;
         case GGML_TYPE_IQ3_XXS_R4:
             SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);