diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 373408a9c0955..ced359d7f329a 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -50,7 +50,7 @@ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp -#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +//#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index fdd0a513b8344..d3b94fc9e2c0e 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -404,6 +404,121 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const block_q4_Kx4 *GGML_RESTRICT q4 = (const block_q4_Kx4*) vx; + const uint8x16_t m4b = vdupq_n_u8(0xf); + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_K *GGML_RESTRICT q8 = (const block_q8_K *) vy; + float32x4_t res = vdupq_n_f32(0); + for (int i = 0; i < nb; i++) { + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->d)); // d0 d1 d2 d3 + float32x4_t q8_d = vdupq_n_f32(q8->d); + float32x4_t g_d = vmulq_f32 (q4_d, q8_d); + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->dmin)); // dmin0 dmin1 dmin2 dmin3 + float32x4_t g_dmin = vmulq_f32(q4_dmin, q8_d); + const uint8_t * GGML_RESTRICT q4_ptr = q4->qs; + const int8_t * GGML_RESTRICT q8_ptr = q8->qs; + int32x4_t prod = vdupq_n_s32(0); + const int16x8_t q8_sums = vpaddq_s16(vld1q_s16(q8->bsums), vld1q_s16(q8->bsums + 8)); + // when using vgetq_lane_s16, its index must be a constant, which cannot be used in a loop, so use vst1q_s16 instead. + int16_t tmp_arry[8]; + vst1q_s16(tmp_arry, q8_sums); + for (int j = 0; j < QK_K / 32; ++j) { + int32x4_t sum0 = vdupq_n_s32(0); + int32x4_t sum1 = vdupq_n_s32(0); + // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3 + int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *)q4->scales + 8 * j)) ; + prod = vmlal_s16(prod, vdup_n_s16(tmp_arry[j]), vget_high_s16(scales_mins)); + uint8x16_t q4_0 = vld1q_u8((const uint8_t *) q4_ptr); + uint8x16_t q4_1 = vld1q_u8((const uint8_t *) q4_ptr + 16); + uint8x16_t q4_2 = vld1q_u8((const uint8_t *) q4_ptr + 32); + uint8x16_t q4_3 = vld1q_u8((const uint8_t *) q4_ptr + 48); + q4_ptr += 64; + int8x16_t q8_0 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr); // 8 个 8-bit + int8x16_t q8_1 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 1); + int8x16_t q8_2 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 2); + int8x16_t q8_3 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 3); + q8_ptr += 32; + + /* low bits + (1) sum0 + b0_000 b0_001 b0_002 b0_003 b0_004 b0_005 b0_006 b0_007 | b1_000 b1_001 b1_002 b1_003 b1_004 b1_005 b1_006 b1_007 + * a0 a1 a2 a3 a4 a5 a6 a7 | a0 a1 a2 a3 a4 a5 a6 a7 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (2) sum1 + b2_000 b2_001 b2_002 b2_003 b2_004 b2_005 b2_006 b2_007 | b3_000 b3_001 b3_002 b3_003 b3_004 b3_005 b3_006 b3_007 + * a0 a1 a2 a3 a4 a5 a6 a7 | a0 a1 a2 a3 a4 a5 a6 a7 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (3) sum0 + b0_008 b0_009 b0_010 b0_011 b0_012 b0_013 b0_014 b0_015 | b1_008 b1_009 b1_010 b1_011 b1_012 b1_013 b1_014 b1_015 + * a8 a9 a10 a11 a12 a13 a14 a15 | a8 a9 a10 a11 a12 a13 a14 a15 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (4) sum1 + b2_008 b2_009 b2_010 b2_011 b2_012 b2_013 b2_014 b2_015 | b3_008 b3_009 b3_010 b3_011 b3_012 b3_013 b3_014 b3_015 + * a8 a9 a10 a11 a12 a13 a14 a15 | a8 a9 a10 a11 a12 a13 a14 a15 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + */ + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_0, m4b)), q8_0); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_1, m4b)), q8_0); + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_2, m4b)), q8_1); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_3, m4b)), q8_1); + + /* high bits + (1) sum0 + b0_016 b0_017 b0_018 b0_019 b0_020 b0_021 b0_022 b0_023 | b1_016 b1_017 b1_018 b1_019 b1_020 b1_021 b1_022 b1_023 + * a16 a17 a18 a19 a20 a21 a22 a23 | a16 a17 a18 a19 a20 a21 a22 a23 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (2) sum1 + b2_016 b2_017 b2_018 b2_019 b2_020 b2_021 b2_022 b2_023 | b3_016 b3_017 b3_018 b3_019 b3_020 b3_021 b3_022 b3_023 + * a16 a17 a18 a19 a20 a21 a22 a23 | a16 a17 a18 a19 a20 a21 a22 a23 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (3) sum0 + b_024 b0_025 b0_026 b0_027 b0_028 b0_029 b0_030 b0_031 | b1_024 b1_025 b1_026 b1_027 b1_028 b1_029 b1_030 b1_031 + * a24 a25 a26 a27 a28 a29 a30 a31 | a24 a25 a26 a27 a28 a29 a30 a31 + |------------dot------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (4) sum1 + b2_024 b2_025 b2_026 b2_027 b2_028 b2_029 b2_030 b2_031 | b3_024 b3_025 b3_026 b3_027 b3_028 b3_029 b3_030 b3_031 + * a24 a25 a26 a27 a28 a29 a30 a31 | a24 a25 a26 a27 a28 a29 a30 a31 + |------------dot------------ | |------------dot-------------| | |------------dot-------------| |------------dot-------------| + */ + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_0, 4)), q8_2); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_1, 4)), q8_2); + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_2, 4)), q8_3); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_3, 4)), q8_3); + float32x4_t sumf = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), vpaddq_s32(sum0, sum1))); + res = vfmaq_f32(res, g_d, sumf); + } + res -= vmulq_f32(g_dmin, vcvtq_f32_s32(prod)); + q4++; + q8++; + } + vst1q_f32(s, res); + s += ncols_interleaved; + } + return; + } +#else + // todo, c implementation +#endif +} + void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1814,6 +1929,298 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; // c implementation will use + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); // row + UNUSED(nc); // column + UNUSED(nb); // block + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_MATMUL_INT8) + const block_q8_Kx4 * GGML_RESTRICT q8_ptr_start = (const block_q8_Kx4 *) vy; + const block_q4_Kx4 * GGML_RESTRICT q4_ptr_start = (const block_q4_Kx4 *) vx; + + const uint8x16_t m4b = vdupq_n_u8(0x0f); + float32x4_t zeros = vdupq_n_f32(0.0f); + int anr = nr - nr % 16; + int row = 0; + // Row loop + for (; row < anr / 4; row += 4) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptrs[4]; + q8_ptrs[0] = q8_ptr_start + (row * nb); + for (int i = 0; i < 3; ++i) { + q8_ptrs[i + 1] = q8_ptrs[i] + nb; + } + // Column loop + for (int col = 0; col < nc / ncols_interleaved; col++) { + const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb); + // init output + float32x4_t res[16]; // final result + for (int i = 0; i < 16; i++) { + res[i] = zeros; + } + // Block loop + for (int64_t b = 0; b < nb; b++) { + float32x4_t g_d[16]; + float32x4_t g_dmin[16]; + int16x8_t q8_bsums[16]; + int32x4_t prod[16]; // store bsums*mins + for (int i = 0; i < 16; i++) { + g_d[i] = zeros; + g_dmin[i] = zeros; + q8_bsums[i] = vdupq_n_s16(0); + prod[i] = vdupq_n_s32(0); + } + // Get global d and dmin + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // col0 col1 col2 col3 + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3 + int16_t tmp_q8_bsums_array[16][8]; + for (int iter = 0; iter < 4; iter++) { + // Calculation when four lines are grouped together + for (int in = 0; in < 4; in++) { + float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptrs[iter][b].d[in]); + g_d[in + 4 * iter] = vmulq_f32(q4_d, scalar_q8_d); + g_dmin[in + 4 * iter] = vmulq_f32(q4_dmin, scalar_q8_d); + // The 16 elements in each row are merged into 8 elements. No loop expansion is performed here + q8_bsums[in + 4 * iter] = vpaddq_s16(vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in), vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in + 8)); + vst1q_s16(tmp_q8_bsums_array[in + 4 * iter], q8_bsums[in + 4 * iter]); + } + } + // The 256 elements in the superblock are processed in 8 steps + for (int sb = 0; sb < QK_K / 32; sb++) { + int32x4_t acc_rows[16]; // the calculated value of qs + int32x4_t sum[16]; // the value of qs after rearranging + for (int i = 0; i < 16; i++) { + acc_rows[i] = vdupq_n_s32(0); + sum[i] = vdupq_n_s32(0); + } + // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3 + int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb)); + uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64); + uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64); + uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64); + uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64); + + int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7) + int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7) + int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15) + int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15) + + int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23) + int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23) + int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31) + int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31) + + // The 16 rows of the left matrix are expanded four times + for (int iter = 0; iter < 4; iter++) { + // Direct loop unrolling + prod[0 + 4 * iter] = vmlal_s16(prod[0 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[0 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter): bsums*mins(0-3) + prod[1 + 4 * iter] = vmlal_s16(prod[1 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[1 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+1): bsums*mins(0-3) + prod[2 + 4 * iter] = vmlal_s16(prod[2 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[2 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+2): bsums*mins(0-3) + prod[3 + 4 * iter] = vmlal_s16(prod[3 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[3 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+3): bsums*mins(0-3) + + int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 128 * sb); // A0(0-7) A1(0-7) + int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7) + + int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15) + int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15) + + int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23) + int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23) + + int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 96 + 128 * sb); // A0(24-31) A1(24-31) + int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31) + + // rearranging vectors + sum[0 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[0 + 4 * iter]), vget_low_s32(acc_rows[1 + 4 * iter])); + sum[1 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[0 + 4 * iter]), vget_high_s32(acc_rows[1 + 4 * iter])); + sum[2 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[2 + 4 * iter]), vget_low_s32(acc_rows[3 + 4 * iter])); + sum[3 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[2 + 4 * iter]), vget_high_s32(acc_rows[3 + 4 * iter])); + + float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0 + 4 * iter])); // scales *qs + float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1 + 4 * iter])); + float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2 + 4 * iter])); + float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3 + 4 * iter])); + + res[0 + 4 * iter] = vfmaq_f32(res[0 + 4 * iter], g_d[0 + 4 * iter], sumf_0); + res[1 + 4 * iter] = vfmaq_f32(res[1 + 4 * iter], g_d[1 + 4 * iter], sumf_1); + res[2 + 4 * iter] = vfmaq_f32(res[2 + 4 * iter], g_d[2 + 4 * iter], sumf_2); + res[3 + 4 * iter] = vfmaq_f32(res[3 + 4 * iter], g_d[3 + 4 * iter], sumf_3); + } + } + for (int iter = 0; iter < 4; iter++) { + res[0 + 4 * iter] -= vmulq_f32(g_dmin[0 + 4 * iter], vcvtq_f32_s32(prod[0 + 4 * iter])); + res[1 + 4 * iter] -= vmulq_f32(g_dmin[1 + 4 * iter], vcvtq_f32_s32(prod[1 + 4 * iter])); + res[2 + 4 * iter] -= vmulq_f32(g_dmin[2 + 4 * iter], vcvtq_f32_s32(prod[2 + 4 * iter])); + res[3 + 4 * iter] -= vmulq_f32(g_dmin[3 + 4 * iter], vcvtq_f32_s32(prod[3 + 4 * iter])); + } + } + // store result + for (int i = 0; i < 16; i++) { + vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]); + } + } + } + // Handling tail parts that are less than 16 lines + for (; row < nr / 4; row++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = q8_ptr_start + (row * nb); + // Column loop + for (int col = 0; col < nc / ncols_interleaved; col++) { + const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb); + // init output + float32x4_t res[4]; + for (int i = 0; i < 4; i++) { + res[i] = zeros; + } + // Block loop + for (int64_t b = 0; b < nb; b++) { + float32x4_t g_d[4]; + float32x4_t g_dmin[4]; + int16x8_t q8_bsums[4]; + int32x4_t prod[4]; + for (int i = 0; i < 4; i++) { + g_d[i] = zeros; + g_dmin[i] = zeros; + q8_bsums[i] = vdupq_n_s16(0); + prod[i] = vdupq_n_s32(0); + } + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // col0 col1 col2 col3 + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3 + int16_t tmp_q8_bsums_array[4][8]; + for (int in = 0; in < 4; in++) { + float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptr[b].d[in]); + g_d[in] = vmulq_f32(q4_d, scalar_q8_d); + g_dmin[in] = vmulq_f32(q4_dmin, scalar_q8_d); + q8_bsums[in] = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * in), vld1q_s16(q8_ptr[b].bsums + 16 * in + 8)); + vst1q_s16(tmp_q8_bsums_array[in], q8_bsums[in]); + } + for (int sb = 0; sb < QK_K / 32; sb++) { + int32x4_t acc_rows[4]; + int32x4_t sum[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = vdupq_n_s32(0); + sum[i] = vdupq_n_s32(0); + } + // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3 + int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb)); + uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64); + uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64); + uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64); + uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64); + + int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7) + int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7) + int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15) + int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15) + + prod[0] = vmlal_s16(prod[0], vdup_n_s16(tmp_q8_bsums_array[0][sb]), vget_high_s16(scales_mins)); // row(iter): bsums*mins(0-3) + prod[1] = vmlal_s16(prod[1], vdup_n_s16(tmp_q8_bsums_array[1][sb]), vget_high_s16(scales_mins)); // row(iter+1): bsums*mins(0-3) + prod[2] = vmlal_s16(prod[2], vdup_n_s16(tmp_q8_bsums_array[2][sb]), vget_high_s16(scales_mins)); // row(iter+2): bsums*mins(0-3) + prod[3] = vmlal_s16(prod[3], vdup_n_s16(tmp_q8_bsums_array[3][sb]), vget_high_s16(scales_mins)); // row(iter+3): bsums*mins(0-3) + + int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 128 * sb); // A0(0-7) A1(0-7) + int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7) + + int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15) + int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15) + + int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23) + int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23) + int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31) + int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31) + + int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23) + int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23) + + int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 96 + 128 * sb); // A0(24-31) A1(24-31) + int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31) + + // rearranging vectors + sum[0] = vcombine_s32(vget_low_s32(acc_rows[0]), vget_low_s32(acc_rows[1])); + sum[1] = vcombine_s32(vget_high_s32(acc_rows[0]), vget_high_s32(acc_rows[1])); + sum[2] = vcombine_s32(vget_low_s32(acc_rows[2]), vget_low_s32(acc_rows[3])); + sum[3] = vcombine_s32(vget_high_s32(acc_rows[2]), vget_high_s32(acc_rows[3])); + + float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0])); // scales *qs + float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1])); + float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2])); + float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3])); + + res[0] = vfmaq_f32(res[0], g_d[0], sumf_0); + res[1] = vfmaq_f32(res[1], g_d[1], sumf_1); + res[2] = vfmaq_f32(res[2], g_d[2], sumf_2); + res[3] = vfmaq_f32(res[3], g_d[3], sumf_3); + } + res[0] -= vmulq_f32(g_dmin[0], vcvtq_f32_s32(prod[0])); + res[1] -= vmulq_f32(g_dmin[1], vcvtq_f32_s32(prod[1])); + res[2] -= vmulq_f32(g_dmin[2], vcvtq_f32_s32(prod[2])); + res[3] -= vmulq_f32(g_dmin[3], vcvtq_f32_s32(prod[3])); + } + // store result + for (int i = 0; i < 4; i++) { + vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]); + } + } + } + return; +#else + // todo, c implementation +#endif +} + void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index d95bb6d8aafce..3a521f13c23f4 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -287,230 +287,6 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } -void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK_K == 256); - assert(k % QK_K == 0); - const int nb = k / QK_K; - - block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; - -#if defined(__AVX2__) - float iscale[4]; - __m256 srcv[4][32]; - __m256 iscale_vec[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 ); - __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 ); - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); - __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); - __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); - __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); - - __m256 maxAbs = _mm256_max_ps( abs0, abs1 ); - maxAbs = _mm256_max_ps( maxAbs, abs2 ); - maxAbs = _mm256_max_ps( maxAbs, abs3 ); - - __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); - __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); - __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); - __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); - - __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); - - srcv[row_iter][0] = v0; - srcv[row_iter][1] = v1; - srcv[row_iter][2] = v2; - srcv[row_iter][3] = v3; - - for (int sb = 1; sb < 8; sb++) { - // Temporarily stores absolute quant values - __m256 tempAbs = maxAbs; - - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32); - __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 ); - - // Compute max(abs(e)) for the block - __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); - __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); - __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); - __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); - - maxAbs = _mm256_max_ps( maxAbs, abs0 ); - maxAbs = _mm256_max_ps( maxAbs, abs1 ); - maxAbs = _mm256_max_ps( maxAbs, abs2 ); - maxAbs = _mm256_max_ps( maxAbs, abs3 ); - - __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ ); - maskAbs = _mm256_and_ps( maskAbs, mask_prev ); - - mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); - mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); - mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); - mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); - - __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); - maskAbs = _mm256_or_ps(maskAbs, mask_curr); - - srcv[row_iter][sb * 4] = v0; - srcv[row_iter][sb * 4 + 1] = v1; - srcv[row_iter][sb * 4 + 2] = v2; - srcv[row_iter][sb * 4 + 3] = v3; - } - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - __m256 maxScalarVec = _mm256_set1_ps(maxScalar); - - __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ ); - __m256 finalMask = _mm256_and_ps(maskAbs, mask_next); - - const int mask = _mm256_movemask_ps(finalMask); - iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - - if(mask) { - iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f; - } - - y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0; - iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]); - } - - __m256i quants_interleaved[32]; - for (int j = 0; j < 32; j++) { - // Apply the multiplier - __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]); - __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]); - __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]); - __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); - i2 = _mm256_packs_epi32( i2, i3 ); - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); - - // Permute and store the quantized weights in the required order after the pack instruction - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); - quants_interleaved[j] = i0; - } - - // Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation - __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15)); - shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0); - __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15)); - shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0); - __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9)); - shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0); - - for (int k = 0; k < 4; k++) { - // Quants from four different sub blocks are taken - __m256i q0 = quants_interleaved[k * 8 + 0]; - __m256i q1 = quants_interleaved[k * 8 + 1]; - __m256i q2 = quants_interleaved[k * 8 + 2]; - __m256i q3 = quants_interleaved[k * 8 + 3]; - __m256i q4 = quants_interleaved[k * 8 + 4]; - __m256i q5 = quants_interleaved[k * 8 + 5]; - __m256i q6 = quants_interleaved[k * 8 + 6]; - __m256i q7 = quants_interleaved[k * 8 + 7]; - - - // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time - __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); - __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); - __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); - __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); - - __m256i one = _mm256_set1_epi8(1); - __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved); - - for (int l = 0; l < 3; l++) { - // Quants value shifted to process next two values from each sub block - q0 = _mm256_srli_epi64(q0, 16); - q2 = _mm256_srli_epi64(q2, 16); - q4 = _mm256_srli_epi64(q4, 16); - q6 = _mm256_srli_epi64(q6, 16); - - sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); - sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); - sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); - sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); - - bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved)); - } - - // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time - __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); - __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); - __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); - __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); - - __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved); - - for (int l = 0; l < 3; l++) { - // Quants value shifted to process next two values from each sub block - q1 = _mm256_srli_epi64(q1, 16); - q3 = _mm256_srli_epi64(q3, 16); - q5 = _mm256_srli_epi64(q5, 16); - q7 = _mm256_srli_epi64(q7, 16); - - sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); - sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); - sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); - sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); - - bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved)); - } - - // Overall bsums in interleaved fashion computed by adding results of both halves - __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2); - _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r); - } - } - -#else - UNUSED(nb); - UNUSED(y); - ggml_quantize_mat_q8_K_4x8_generic(x, vy, k); -#endif -} - // // GEMV/GEMM templates // diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f531d21e23224..05cff94d5d866 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -175,6 +175,346 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG } } +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + +#if defined(__ARM_NEON) + const int blck_size_interleave = 8; + float32x4_t srcv[4][64]; // 64 = QK_K/4 + float iscale[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[64]; + float32x4_t amaxv[64]; + + // d: + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j); + for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]); + for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]); + for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]); + + const float amax = vmaxvq_f32(amaxv[0]); + + // Check if exists: orig == amax + float32x4_t amax_vec = vdupq_n_f32(amax); + uint32x4_t mask_all = vdupq_n_u32(0); + for (int j = 0; j < 64; j++) { + uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]); + mask_all = vorrq_u32(mask_all, mask_curr); + } + + // Assume that none == amax, then check mask_all to reverse + iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f; + uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu)); + if (vmaxvq_u32(cmp) != 0) { + iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f; + } + + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; + } + + // qs: 8 byte interleave over 4 rows, loop = QK_K/8 + // bsums: simply generated one by one, row_i is calculated before row_i+1 + // loops = 16 + for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) { + // Process row0 and row1 + float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0])); + float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0])); + float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0])); + float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0])); + int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3); + int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7); + int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11); + int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15); + int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1])); + float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1])); + float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1])); + float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1])); + int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3); + int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7); + int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11); + int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15); + int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + // Calculate and store qs + int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 + int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7); + vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15); + // Calculate and store bsum + int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + y[i].bsums[j] = vaddlvq_s8(i0_0_15); + y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15); + + // Process row2 and row3 + f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2])); + f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2])); + f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2])); + f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2])); + i0_0_3 = vcvtnq_s32_f32(f0_0_3); + i0_4_7 = vcvtnq_s32_f32(f0_4_7); + i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + i0_8_11 = vcvtnq_s32_f32(f0_8_11); + i0_12_15 = vcvtnq_s32_f32(f0_12_15); + i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3])); + f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3])); + f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3])); + f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3])); + i1_0_3 = vcvtnq_s32_f32(f1_0_3); + i1_4_7 = vcvtnq_s32_f32(f1_4_7); + i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + i1_8_11 = vcvtnq_s32_f32(f1_8_11); + i1_12_15 = vcvtnq_s32_f32(f1_12_15); + i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + // Calculate and store qs + i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 + i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7); + vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15); + // Calculate and store bsum + i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15); + y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15); + } + } +#elif defined(__AVX2__) + float iscale[4]; + __m256 srcv[4][32]; + __m256 iscale_vec[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 ); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); + __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); + __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); + __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); + + __m256 maxAbs = _mm256_max_ps( abs0, abs1 ); + maxAbs = _mm256_max_ps( maxAbs, abs2 ); + maxAbs = _mm256_max_ps( maxAbs, abs3 ); + + __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); + __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); + __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); + __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); + + __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); + + srcv[row_iter][0] = v0; + srcv[row_iter][1] = v1; + srcv[row_iter][2] = v2; + srcv[row_iter][3] = v3; + + for (int sb = 1; sb < 8; sb++) { + // Temporarily stores absolute quant values + __m256 tempAbs = maxAbs; + + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 ); + + // Compute max(abs(e)) for the block + __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); + __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); + __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); + __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); + + maxAbs = _mm256_max_ps( maxAbs, abs0 ); + maxAbs = _mm256_max_ps( maxAbs, abs1 ); + maxAbs = _mm256_max_ps( maxAbs, abs2 ); + maxAbs = _mm256_max_ps( maxAbs, abs3 ); + + __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ ); + maskAbs = _mm256_and_ps( maskAbs, mask_prev ); + + mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); + mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); + mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); + mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); + + __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); + maskAbs = _mm256_or_ps(maskAbs, mask_curr); + + srcv[row_iter][sb * 4] = v0; + srcv[row_iter][sb * 4 + 1] = v1; + srcv[row_iter][sb * 4 + 2] = v2; + srcv[row_iter][sb * 4 + 3] = v3; + } + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + __m256 maxScalarVec = _mm256_set1_ps(maxScalar); + + __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ ); + __m256 finalMask = _mm256_and_ps(maskAbs, mask_next); + + const int mask = _mm256_movemask_ps(finalMask); + iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + + if(mask) { + iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f; + } + + y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0; + iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]); + } + + __m256i quants_interleaved[32]; + for (int j = 0; j < 32; j++) { + // Apply the multiplier + __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]); + __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]); + __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]); + __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); + + // Permute and store the quantized weights in the required order after the pack instruction + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); + quants_interleaved[j] = i0; + } + + // Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation + __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15)); + shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0); + __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15)); + shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0); + __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9)); + shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0); + + for (int k = 0; k < 4; k++) { + // Quants from four different sub blocks are taken + __m256i q0 = quants_interleaved[k * 8 + 0]; + __m256i q1 = quants_interleaved[k * 8 + 1]; + __m256i q2 = quants_interleaved[k * 8 + 2]; + __m256i q3 = quants_interleaved[k * 8 + 3]; + __m256i q4 = quants_interleaved[k * 8 + 4]; + __m256i q5 = quants_interleaved[k * 8 + 5]; + __m256i q6 = quants_interleaved[k * 8 + 6]; + __m256i q7 = quants_interleaved[k * 8 + 7]; + + + // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time + __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); + __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); + __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); + __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); + + __m256i one = _mm256_set1_epi8(1); + __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved); + + for (int l = 0; l < 3; l++) { + // Quants value shifted to process next two values from each sub block + q0 = _mm256_srli_epi64(q0, 16); + q2 = _mm256_srli_epi64(q2, 16); + q4 = _mm256_srli_epi64(q4, 16); + q6 = _mm256_srli_epi64(q6, 16); + + sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); + sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); + sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); + sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); + + bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved)); + } + + // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time + __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); + __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); + __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); + __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); + + __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved); + + for (int l = 0; l < 3; l++) { + // Quants value shifted to process next two values from each sub block + q1 = _mm256_srli_epi64(q1, 16); + q3 = _mm256_srli_epi64(q3, 16); + q5 = _mm256_srli_epi64(q5, 16); + q7 = _mm256_srli_epi64(q7, 16); + + sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); + sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); + sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); + sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); + + bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved)); + } + + // Overall bsums in interleaved fashion computed by adding results of both halves + __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2); + _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r); + } + } + +#else + // NOTE: This default c implementation is aligned with AVX2 implemantation, but differs from arm implementation. + // especially in bsums arrangement. + UNUSED(nb); + UNUSED(y); + ggml_quantize_mat_q8_K_4x8_generic(x, vy, k); +#endif +} + } // extern "C" template @@ -1078,6 +1418,90 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static void make_block_q4_Kx4(block_q4_K * in, unsigned int blck_size_interleave, block_q4_Kx4 * out) { + int nrow = 4; + int nloop = 4; + + // d and dmin values of the 4 Q4_K are copied directly. + for (int i = 0; i < nrow; i++) { + out->d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < nrow; i++) { + out->dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + // For qs, 2 things need to be done: + // 1. Recover from Q4_K storage tyle to Q4_0 style + // 2. Interleave quants by taking 8 bytes at a time + + // 1. + const uint64_t lo_mask = 0x0f0f0f0f0f0f0f0fULL; + const uint64_t hi_mask = 0xf0f0f0f0f0f0f0f0ULL; + for (int i = 0; i < nrow; i++) { + uint64_t *q = (uint64_t *)(in[i].qs); + for (int j = 0; j < nloop; j++) { + uint64_t q0, q1, q2, q3; + q0 = q[0]; + q1 = q[1]; + q2 = q[2]; + q3 = q[3]; + + uint64_t hi1, hi2, lo3, lo4; + hi1 = q0 & hi_mask; + hi2 = q1 & hi_mask; + lo3 = q2 & lo_mask; + lo4 = q3 & lo_mask; + q[0] = (q0 & lo_mask) | (lo3 << 4); + q[1] = (q1 & lo_mask) | (lo4 << 4); + q[2] = (q2 & hi_mask) | (hi1 >> 4); + q[3] = (q3 & hi_mask) | (hi2 >> 4); + + q += 4; + } + } + + // 2. + // Calculate total number of interleaved subblocks + const int end = QK_K * 2 / blck_size_interleave; + uint64_t *src, *dst; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + src = (uint64_t *)(&in[src_id].qs[src_offset]); + dst = (uint64_t *)(&out->qs[dst_offset]); + *dst = *src; + } + + // For scales & mins of each subblock. (8 subblocks in one Q4_K, 32 in total) + // A special requirement to meet: expand to 8-bit from 6-bit. + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + uint32_t utmp[4]; + for (int i = 0; i < nrow; i++) { + // rearrange as d|d|...|d|min|min|...|min + // expand to 8-bit from 6-bit + memset(utmp, 0, 16); + memcpy(utmp, in[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // move to Q4_K + const uint8_t * d_ptr = (const uint8_t*)&utmp[0]; + const uint8_t * m_ptr = (const uint8_t*)&utmp[2]; + for (int j = 0; j < 8; j++) { + out->scales[j * 8 + i] = *d_ptr++; + out->scales[j * 8 + i + nrow] = *m_ptr++; + } + } +} + static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -1228,6 +1652,47 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } + +static int repack_q4_K_to_q4_K_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 4; + + block_q4_Kx4 * dst = (block_q4_Kx4 *)t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_K dst_tmp[4]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + make_block_q4_Kx4(dst_tmp, interleave_block, dst++); + } + src += nrows_interleaved * nblocks; + } + + // change tensor shape as block_q4_kx4 brings space size change + //t->nb[0] = ggml_type_size(type); + t->nb[0] = sizeof(block_q4_Kx4) / 4; + t->nb[1] = t->nb[0] * (t->ne[0] / ggml_blck_size(t->type)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + t->nb[i] = t->nb[i - 1] * t->ne[i - 1]; + } + + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); GGML_ASSERT(interleave_block == 8); @@ -1464,6 +1929,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_4_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); } @@ -1501,6 +1970,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1533,6 +2006,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1685,7 +2162,7 @@ template from_float; // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + //GGML_ASSERT(nb00 == ggml_type_size(src0->type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); // dst cannot be transposed or permuted @@ -1816,6 +2293,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; + static const ggml::cpu::repack::tensor_traits q4_K_4x8_q8_K; // new for ARM N2 static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q2 @@ -1846,6 +2324,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q4_K_8x8_q8_K; } + } else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { // new for ARM N2 + if (cur->ne[1] % 4 == 0) { + return &q4_K_4x8_q8_K; + } } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { @@ -1915,6 +2397,26 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf GGML_UNUSED(buft); } +// size calculation after q4_kx4 repacking, it's different from traditional type +size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) { + size_t nbytes; + const size_t blck_size = 256; + const size_t type_size = sizeof(block_q4_Kx4) / 4; + nbytes = ((tensor->ne[0] * type_size) / blck_size) * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + return nbytes; +} + +static size_t ggml_backend_cpu_aarch64_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + if (tensor->type == GGML_TYPE_Q4_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + return ggml_nbytes_q4_kx4(tensor); + } + } + return ggml_nbytes(tensor); + + GGML_UNUSED(buft); +} + namespace ggml::cpu::repack { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { @@ -1971,7 +2473,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) { /* .alloc_buffer = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cpu_repack_buffer_type_get_alignment, /* .get_max_size = */ nullptr, // defaults to SIZE_MAX - /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .get_alloc_size = */ ggml_backend_cpu_aarch64_buffer_type_get_alloc_size, // defaults to ggml_nbytes except for ARM N2 /* .is_host = */ nullptr, }, /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb32b503d3a11..3b10523bd55aa 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -36,6 +36,40 @@ using block_q4_0x8 = block<4, 8>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; +struct block_q4_Kx4 { + ggml_half d[4]; // super-block scale for quantized scales + ggml_half dmin[4]; // super-block scale for quantized mins + int8_t scales[64]; // scales and mins, quantized with 8 bits (recover from 6-bit during repack) TODO: consider if uint8_t? + uint8_t qs[512]; // 4--bit quants + + /********************************************************layout***************************************************************/ + // low <-------------------------------------------------------------------------------------> high + // + // d: |s0|s1|s2|s3| + // + // dmin: |s0|s1|s2|s3| + // + // scales: |-------- d --------|-------- m --------| + // |s0b0|s1b0|s2b0|s3b0|s0b0|s1b0|s2b0|s3b0| + // |s0b1|s1b1|s2b1|s3b1|s0b1|s1b1|s2b1|s3b1| + // ...... + // |s0b7|s1b7|s2b7|s3b7|s0b7|s1b7|s2b7|s3b7| + // + // qs: + // |s0w0 |s0w16|s0w1 |s0w17|s0w2 |s0w18|s0w3 |s0w19|s0w4 |s0w20|s0w5 |s0w21|s0w6 |s0w22|s0w7 |s0w23| --- 8B from s0 + // |s1w0 |s1w16|s1w1 |s1w17|s1w2 |s1w18|s1w3 |s1w19|s1w4 |s1w20|s1w5 |s1w21|s1w6 |s1w22|s1w7 |s1w23| --- 8B from s1 + // |s2w0 |s2w16|s2w1 |s2w17|s2w2 |s2w18|s2w3 |s2w19|s2w4 |s2w20|s2w5 |s2w21|s2w6 |s2w22|s2w7 |s2w23| --- 8B from s2 + // |s3w0 |s3w16|s3w1 |s3w17|s3w2 |s3w18|s3w3 |s3w19|s3w4 |s3w20|s3w5 |s3w21|s3w6 |s3w22|s3w7 |s3w23| --- 8B from s3 + // |s0w8 |s0w24|s0w9 |s0w25|s0w10|s0w26|s0w11|s0w27|s0w12|s0w28|s0w13|s0w29|s0w14|s0w30|s0w15|s0w31| --- 8B from s0 + // |s1w8 |s1w24|s1w9 |s1w25|s1w10|s1w26|s1w11|s1w27|s1w12|s1w28|s1w13|s1w29|s1w14|s1w30|s1w15|s1w31| --- 8B from s1 + // |s2w8 |s2w24|s2w9 |s2w25|s2w10|s2w26|s2w11|s2w27|s2w12|s2w28|s2w13|s2w29|s2w14|s2w30|s2w15|s2w31| --- 8B from s2 + // |s3w8 |s3w24|s3w9 |s3w25|s3w10|s3w26|s3w11|s3w27|s3w12|s3w28|s3w13|s3w29|s3w14|s3w30|s3w15|s3w31| --- 8B from s3 + // + // ...... + // + /*****************************************************************************************************************************/ +}; + struct block_q4_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -84,6 +118,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -91,6 +126,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -104,12 +140,14 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +// gemv_generic ??? void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +// gemm_generic ??? void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);