@@ -525,67 +525,47 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
525525 UNUSED (ncols_interleaved );
526526 UNUSED (blocklen );
527527
528- #if ! ((defined(_MSC_VER )) && ! defined(__clang__ )) && defined(__aarch64__ ) && defined(__ARM_NEON )
528+ #if ! ((defined(_MSC_VER )) && ! defined(__clang__ )) && defined(__aarch64__ ) && defined(__ARM_NEON ) && defined( __ARM_FEATURE_DOTPROD )
529529 if (ggml_cpu_has_neon () && ggml_cpu_has_dotprod ()) {
530- const void * b_ptr = vx ;
531- const void * a_ptr = vy ;
532- float * res_ptr = s ;
533-
534- __asm__ __volatile__(
535- "movi v31.16b, #0x4\n"
536- "movi v30.16b, #0xf0\n"
537- "add %x[b_ptr], %x[b_ptr], #0x8\n"
538- "1:" // Column loop
539- "add x22, %x[a_ptr], #0x2\n"
540- "movi v29.16b, #0x0\n"
541- "mov x21, %x[nb]\n"
542- "2:" // Block loop
543- "ldr q28, [%x[b_ptr], #0x0]\n"
544- "ldr q27, [x22, #0x0]\n"
545- "movi v26.4s, #0x0\n"
546- "sub x20, x22, #0x2\n"
547- "ldr q25, [x22, #0x10]\n"
548- "ldr q24, [%x[b_ptr], #0x10]\n"
549- "sub x21, x21, #0x1\n"
550- "add x22, x22, #0x22\n"
551- "ldr q23, [%x[b_ptr], #0x20]\n"
552- "ldr q22, [%x[b_ptr], #0x30]\n"
553- "ld1r { v21.8h }, [x20]\n"
554- "ldr q20, [%x[b_ptr], #-0x8]\n"
555- "sshl v16.16b, v28.16b, v31.16b\n"
556- "and v28.16b, v28.16b, v30.16b\n"
557- "sshl v19.16b, v24.16b, v31.16b\n"
558- "and v24.16b, v24.16b, v30.16b\n"
559- "add %x[b_ptr], %x[b_ptr], #0x48\n"
560- "sshl v18.16b, v23.16b, v31.16b\n"
561- "and v23.16b, v23.16b, v30.16b\n"
562- ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
563- "sshl v17.16b, v22.16b, v31.16b\n"
564- "and v22.16b, v22.16b, v30.16b\n"
565- "fcvtl v21.4s, v21.4h\n"
566- "fcvtl v16.4s, v20.4h\n"
567- ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
568- "fmul v16.4s, v16.4s, v21.4s\n"
569- ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
570- ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
571- ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
572- ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
573- ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
574- ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
575- "scvtf v26.4s, v26.4s, #0x4\n"
576- "fmla v29.4s, v26.4s, v16.4s\n"
577- "cbnz x21, 2b\n"
578- "sub %x[nc], %x[nc], #0x4\n"
579- "str q29, [%x[res_ptr], #0x0]\n"
580- "add %x[res_ptr], %x[res_ptr], #0x10\n"
581- "cbnz %x[nc], 1b\n"
582- : [b_ptr ] "+&r" (b_ptr ), [res_ptr ] "+&r" (res_ptr ), [nc ] "+&r" (nc )
583- : [a_ptr ] "r" (a_ptr ), [nb ] "r" (nb )
584- : "memory" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , "v26" , "v27" , "v28" , "v29" , "v30" , "v31" , "x20" , "x21" , "x22"
585- );
530+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 * )vx ;
531+
532+ for (int c = 0 ; c < nc ; c += ncols_interleaved ) {
533+ const block_q8_0 * a_ptr = (const block_q8_0 * )vy ;
534+ float32x4_t acc = vdupq_n_f32 (0 );
535+ for (int b = 0 ; b < nb ; b ++ ) {
536+ int8x16_t b0 = vld1q_s8 ((const int8_t * )b_ptr -> qs );
537+ int8x16_t b1 = vld1q_s8 ((const int8_t * )b_ptr -> qs + 16 );
538+ int8x16_t b2 = vld1q_s8 ((const int8_t * )b_ptr -> qs + 32 );
539+ int8x16_t b3 = vld1q_s8 ((const int8_t * )b_ptr -> qs + 48 );
540+ float16x4_t bd = vld1_f16 ((const __fp16 * )b_ptr -> d );
541+
542+ int8x16_t a0 = vld1q_s8 (a_ptr -> qs );
543+ int8x16_t a1 = vld1q_s8 (a_ptr -> qs + qk /2 );
544+ float16x4_t ad = vld1_dup_f16 ((const __fp16 * )& a_ptr -> d );
545+
546+ int32x4_t ret = vdupq_n_s32 (0 );
547+
548+ ret = vdotq_laneq_s32 (ret , b0 << 4 , a0 , 0 );
549+ ret = vdotq_laneq_s32 (ret , b1 << 4 , a0 , 1 );
550+ ret = vdotq_laneq_s32 (ret , b2 << 4 , a0 , 2 );
551+ ret = vdotq_laneq_s32 (ret , b3 << 4 , a0 , 3 );
552+
553+ ret = vdotq_laneq_s32 (ret , b0 & 0xf0U , a1 , 0 );
554+ ret = vdotq_laneq_s32 (ret , b1 & 0xf0U , a1 , 1 );
555+ ret = vdotq_laneq_s32 (ret , b2 & 0xf0U , a1 , 2 );
556+ ret = vdotq_laneq_s32 (ret , b3 & 0xf0U , a1 , 3 );
557+
558+ acc = vfmaq_f32 (acc , vcvtq_n_f32_s32 (ret , 4 ),
559+ vmulq_f32 (vcvt_f32_f16 (ad ), vcvt_f32_f16 (bd )));
560+ a_ptr ++ ;
561+ b_ptr ++ ;
562+ }
563+ vst1q_f32 (s , acc );
564+ s += ncols_interleaved ;
565+ }
586566 return ;
587567 }
588- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
568+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
589569 float sumf [4 ];
590570 int sumi ;
591571
0 commit comments