
Commit 28e30c2

Guard gemm with proper features, improved superblock scale and min calc
Signed-off-by: Alberto Cabrera <[email protected]>
1 parent f9e1527 commit 28e30c2
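
In outline, the guard this commit adds compiles the NEON + i8mm kernel only when the required compiler and target features are available, and otherwise falls through to the generic kernel. A minimal sketch of the resulting control flow (parameter list abbreviated from the other repack gemm signatures in this file; the full body is in the diff below):

void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs,
                             const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy,
                             int nr, int nc) {
    // ... asserts and UNUSED() markers ...
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    // vectorized NEON + i8mm path (see the hunks below)
    return;   // skip the fallback when the fast path was compiled in
#endif
    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

The return before the #endif is what keeps the two paths mutually exclusive: when the feature test fails, the preprocessor drops the whole vector body and only the _generic call remains.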

File tree: 1 file changed, 48 additions and 45 deletions


ggml/src/ggml-cpu/arch/arm/repack.cpp

Lines changed: 48 additions & 45 deletions
@@ -24,6 +24,30 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t * out_mins,
+                                             int8_t * out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t mins_0_3 = sm[1] & kmask1;
+    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+
+
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -1890,29 +1914,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
-                                             int16x8_t * out_mins,
-                                             int8_t * out_scales) {
-    constexpr uint32_t kmask1 = 0x3f3f3f3f;
-    constexpr uint32_t kmask2 = 0x0f0f0f0f;
-    constexpr uint32_t kmask3 = 0x03030303;
-    constexpr uint8_t scales_size = 12;
-
-    uint32_t sm[3];
-    memcpy(sm, scales_in, scales_size);
-
-    const uint32_t mins_0_3 = sm[1] & kmask1;
-    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
-    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
-
-    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
-
-    uint32_t scales_u32[2];
-    scales_u32[0] = sm[0] & kmask1;
-    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
-    memcpy(out_scales, scales_u32, 8);
-}
-
 
 void ggml_gemm_q4_K_8x8_q8_K(int n,
                              float * GGML_RESTRICT s,
@@ -1943,6 +1944,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
 
     // 8 accumulators: 2 row pairs × 4 col pairs
@@ -1960,17 +1962,21 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
             for (int b = 0; b < nb; b++) {
                 // bsums pairs belongs to the same q8_k subblock
-                const int16x8_t y_bsums[4]{
+                const int16x8_t bsums[4]{
                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
                     vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
                 };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
 
                 int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
                 int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
                 int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
-                for (int i = 0; i < 8; ++i) {
+                for (int i = 0; i < 8; i++) {
                     acc[i] = vdupq_n_s32(0);
                     bias_acc[i] = vdupq_n_s32(0);
                 }
@@ -1992,7 +1998,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                     int8x16_t q8_qs_23[8];
 
                     // Load 32-byte per row pair, 1 subblock each time
-                    for (int i = 0; i < 8; ++i) {
+                    for (int i = 0; i < 8; i++) {
                         const int offset = i * 32; // 16 for row 01, 16 for row 23
                         q8_qs_01[i] = vld1q_s8(q8_base + offset);
                         q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
@@ -2007,7 +2013,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
                     // Q4s columns iterated in pairs (01, 23, 45, 67)
                     for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
-                        for (int i = 0; i < 4; ++i) {
+                        for (int i = 0; i < 4; i++) {
                             sb_acc[i] = vdupq_n_s32(0);
                         }
 
@@ -2063,16 +2069,16 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                     for (int q8_row = 0; q8_row < 4; q8_row++) {
                         // Each pair of subblocks share the same bsums
                         // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
-                        int16x8_t bsums_vec_lo = vdupq_n_s16(y_bsums[sb][q8_row * 2]);
-                        int16x8_t bsums_vec_hi = vdupq_n_s16(y_bsums[sb][q8_row * 2 + 1]);
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
 
                         bias_acc[2 * q8_row] =
-                            vmlal_s16(bias_acc[2 * q8_row], vget_low_s16(bsums_vec_lo), vget_low_s16(q4sb_mins[0]));
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
                         bias_acc[2 * q8_row] =
-                            vmlal_s16(bias_acc[2 * q8_row], vget_low_s16(bsums_vec_hi), vget_low_s16(q4sb_mins[1]));
-                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], vget_high_s16(bsums_vec_lo),
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo,
                                                              vget_high_s16(q4sb_mins[0]));
-                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], vget_high_s16(bsums_vec_hi),
+                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi,
                                                              vget_high_s16(q4sb_mins[1]));
                     }
                 } // for sb
@@ -2095,19 +2101,13 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
                 for (int i = 0; i < q8_k_blocklen; i++) {
                     for (int j = 0; j < 2; j++) {
-                        const float32x4_t dmins = {
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 0]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 1]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 2]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 3]),
-                        };
+                        // TODO: Change to a single vmul
+                        float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *)(q4_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
 
-                        const float32x4_t scale = {
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 0]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 1]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 2]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 3]),
-                        };
+                        float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *)(q4_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
 
                         acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
                         acc_f32[2 * i + j] =
@@ -2127,5 +2127,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
             }
         } // for x
     } // for y
+    return;
+#endif
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
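For reference, the relocated decode_q4_Kx8_scales_mins helper unpacks the 12-byte q4_K scale/min block: 8 scales and 8 mins of 6 bits each, where entries 0-3 sit in the low 6 bits of bytes 0-3 (scales) and 4-7 (mins), and entries 4-7 are split between a nibble in bytes 8-11 and the top 2 bits of bytes 0-7. A scalar sketch of the same unpacking, equivalent to the masks used above (the helper name here is illustrative, not part of the patch):

// Scalar view of the packing that decode_q4_Kx8_scales_mins handles with
// 32-bit masks (the vector version additionally relies on little-endian loads).
static inline void decode_scale_min_scalar(int j, const uint8_t * q, uint8_t * sc, uint8_t * mn) {
    if (j < 4) {
        *sc = q[j]     & 63;                               // low 6 bits of bytes 0-3
        *mn = q[j + 4] & 63;                               // low 6 bits of bytes 4-7
    } else {
        *sc = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);  // low nibble of bytes 8-11 + top 2 bits of bytes 0-3
        *mn = (q[j + 4] >>   4) | ((q[j]     >> 6) << 4);  // high nibble of bytes 8-11 + top 2 bits of bytes 4-7
    }
}

The vector helper computes all eight entries at once: kmask1 (0x3f per byte) keeps the low 6 bits, kmask2 (0x0f per byte) the nibble carried in bytes 8-11, and kmask3 (0x03 per byte) the two high bits brought down by the >> 6 shifts.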