Skip to content

Commit 18d20c0

Browse files
matmul-int8: enable matmul-int8 with MSVC and fix Clang warnings
1 parent 7d46953 commit 18d20c0

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,11 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
10071007
if (GGML_COMPILER_SUPPORT_DOTPROD)
10081008
add_compile_definitions(__ARM_FEATURE_DOTPROD)
10091009
endif ()
1010+
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
1011+
if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
1012+
add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
1013+
endif ()
1014+
10101015
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
10111016
if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
10121017
add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

ggml-quants.c

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3487,10 +3487,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
34873487
#if defined(__ARM_FEATURE_MATMUL_INT8)
34883488
if (nrc == 2) {
34893489
const block_q4_0 * restrict vx0 = vx;
3490-
const block_q4_0 * restrict vx1 = vx + bx;
3491-
3490+
const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
34923491
const block_q8_0 * restrict vy0 = vy;
3493-
const block_q8_0 * restrict vy1 = vy + by;
3492+
const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
34943493

34953494
float32x4_t sumv0 = vdupq_n_f32(0.0f);
34963495

@@ -3524,10 +3523,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
35243523
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
35253524
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
35263525

3527-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3528-
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3529-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3530-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3526+
float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3527+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3528+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3529+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3530+
3531+
float32x4_t scale = vld1q_f32(_scale);
35313532

35323533
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
35333534
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3894,9 +3895,9 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
38943895
#if defined(__ARM_FEATURE_MATMUL_INT8)
38953896
if (nrc == 2) {
38963897
const block_q4_1 * restrict vx0 = vx;
3897-
const block_q4_1 * restrict vx1 = vx + bx;
3898+
const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
38983899
const block_q8_1 * restrict vy0 = vy;
3899-
const block_q8_1 * restrict vy1 = vy + by;
3900+
const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
39003901

39013902
float32x4_t sumv0 = vdupq_n_f32(0.0f);
39023903
float32x4_t summs0 = vdupq_n_f32(0.0f);
@@ -3907,11 +3908,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
39073908
const block_q8_1 * restrict b_y0 = &vy0[i];
39083909
const block_q8_1 * restrict b_y1 = &vy1[i];
39093910

3910-
float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3911-
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3912-
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3913-
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3914-
summs0 += summs_t;
3911+
float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
3912+
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
3913+
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
3914+
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
3915+
summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
39153916

39163917
const uint8x16_t m4b = vdupq_n_u8(0x0F);
39173918

@@ -3931,10 +3932,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
39313932
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
39323933

39333934
// mmla into int32x4_t
3934-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3935-
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3936-
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3937-
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3935+
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
3936+
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
3937+
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
3938+
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
3939+
float32x4_t scale = vld1q_f32(_scale);
39383940

39393941
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
39403942
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3953,7 +3955,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
39533955

39543956
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
39553957
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3956-
sumv2 = sumv2 + summs0;
3958+
sumv2 = vaddq_f32(sumv2, summs0);
39573959

39583960
vst1_f32(s, vget_low_f32(sumv2));
39593961
vst1_f32(s + bs, vget_high_f32(sumv2));
@@ -4836,35 +4838,36 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
48364838

48374839
#if defined(__ARM_FEATURE_MATMUL_INT8)
48384840
if (nrc == 2) {
4839-
const block_q8_0 * restrict vx0 = vx;
4840-
const block_q8_0 * restrict vx1 = vx + bx;
4841+
const block_q4_0 * restrict vx0 = vx;
4842+
const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
48414843
const block_q8_0 * restrict vy0 = vy;
4842-
const block_q8_0 * restrict vy1 = vy + by;
4844+
const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
48434845

48444846
float32x4_t sumv0 = vdupq_n_f32(0.0f);
48454847

48464848
for (int i = 0; i < nb; i++) {
4847-
const block_q8_0 * restrict b_x0 = &vx0[i];
4849+
const block_q4_0 * restrict b_x0 = &vx0[i];
48484850
const block_q8_0 * restrict b_y0 = &vy0[i];
48494851

4850-
const block_q8_0 * restrict b_x1 = &vx1[i];
4852+
const block_q4_0 * restrict b_x1 = &vx1[i];
48514853
const block_q8_0 * restrict b_y1 = &vy1[i];
48524854

4853-
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4854-
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4855-
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4856-
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4855+
const int8x16_t x0_l = vld1q_s8((const int8_t*)b_x0->qs);
4856+
const int8x16_t x0_h = vld1q_s8((const int8_t*)b_x0->qs + 16);
4857+
const int8x16_t x1_l = vld1q_s8((const int8_t*)b_x1->qs);
4858+
const int8x16_t x1_h = vld1q_s8((const int8_t*)b_x1->qs + 16);
48574859

48584860
// load y
48594861
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
48604862
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
48614863
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
48624864
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
48634865

4864-
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4865-
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4866-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4867-
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4866+
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4867+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4868+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4869+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4870+
float32x4_t scale = vld1q_f32(_scale);
48684871

48694872
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
48704873
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

0 commit comments

Comments
 (0)