Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 38 additions & 58 deletions ggml/src/ggml-cpu/ggml-cpu-aarch64.c
Original file line number Diff line number Diff line change
Expand Up @@ -525,67 +525,47 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const void * b_ptr = vx;
const void * a_ptr = vy;
float * res_ptr = s;

__asm__ __volatile__(
"movi v31.16b, #0x4\n"
"movi v30.16b, #0xf0\n"
"add %x[b_ptr], %x[b_ptr], #0x8\n"
"1:" // Column loop
"add x22, %x[a_ptr], #0x2\n"
"movi v29.16b, #0x0\n"
"mov x21, %x[nb]\n"
"2:" // Block loop
"ldr q28, [%x[b_ptr], #0x0]\n"
"ldr q27, [x22, #0x0]\n"
"movi v26.4s, #0x0\n"
"sub x20, x22, #0x2\n"
"ldr q25, [x22, #0x10]\n"
"ldr q24, [%x[b_ptr], #0x10]\n"
"sub x21, x21, #0x1\n"
"add x22, x22, #0x22\n"
"ldr q23, [%x[b_ptr], #0x20]\n"
"ldr q22, [%x[b_ptr], #0x30]\n"
"ld1r { v21.8h }, [x20]\n"
"ldr q20, [%x[b_ptr], #-0x8]\n"
"sshl v16.16b, v28.16b, v31.16b\n"
"and v28.16b, v28.16b, v30.16b\n"
"sshl v19.16b, v24.16b, v31.16b\n"
"and v24.16b, v24.16b, v30.16b\n"
"add %x[b_ptr], %x[b_ptr], #0x48\n"
"sshl v18.16b, v23.16b, v31.16b\n"
"and v23.16b, v23.16b, v30.16b\n"
".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
"sshl v17.16b, v22.16b, v31.16b\n"
"and v22.16b, v22.16b, v30.16b\n"
"fcvtl v21.4s, v21.4h\n"
"fcvtl v16.4s, v20.4h\n"
".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
"fmul v16.4s, v16.4s, v21.4s\n"
".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
"scvtf v26.4s, v26.4s, #0x4\n"
"fmla v29.4s, v26.4s, v16.4s\n"
"cbnz x21, 2b\n"
"sub %x[nc], %x[nc], #0x4\n"
"str q29, [%x[res_ptr], #0x0]\n"
"add %x[res_ptr], %x[res_ptr], #0x10\n"
"cbnz %x[nc], 1b\n"
: [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
: "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
);
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;

for (int c = 0; c < nc; c += ncols_interleaved) {
const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
float32x4_t acc = vdupq_n_f32(0);
for (int b = 0; b < nb; b++) {
int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);

int8x16_t a0 = vld1q_s8(a_ptr->qs);
int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);

int32x4_t ret = vdupq_n_s32(0);

ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);

ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);

acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
a_ptr++;
b_ptr++;
}
vst1q_f32(s, acc);
s += ncols_interleaved;
}
return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
float sumf[4];
int sumi;

Expand Down
Loading