@@ -991,6 +991,81 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
     return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
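+        // gather index: repeats activation byte offsets 0..7 across the first half of
+        // the 128 gathered lanes and 8..15 across the second half, so one vrgather
+        // replicates each 8-byte activation group once per interleaved column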
+        const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
+        const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
+        const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl * 2); // fill the whole m2 group: its upper half feeds the m4 gather index
+        const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
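+        // column masks: maskN selects lanes 8*N..8*N+7 in each 64-lane half of the
+        // int16 products, i.e. the 16 products that belong to interleaved column N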
+        const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
+        const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
+        const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
+        const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
+        const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
+        const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
+        const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
+        const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
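+            // b_ptr: x-th tile of 8 interleaved columns, nb blocks long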
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
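+                // replicate the 32 activation bytes so every 8-byte group appears once
+                // per column: lhs_vec_lo covers bytes 0..15, lhs_vec_hi covers bytes 16..31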
+                const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
+                const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
+                const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
+                const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
+                const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
+
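+                // unpack the 4-bit weights: shift left then arithmetic-shift right
+                // sign-extends the low nibbles, a plain arithmetic shift gives the high ones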
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+
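+                // widening multiply into int16 and add the two nibble halves;
+                // |product| <= 8 * 128, so the 16-bit sums cannot overflow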
+                const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
+                const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
+                const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
+
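+                // masked reductions produce one column's dot product at a time in lane 0;
+                // the slideups shift earlier results up so iacc0 ends with the sums for
+                // columns 0..7 packed into lanes 0..7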
+                const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
+                const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
+                const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
+                const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
+                const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
+                const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
+                const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
+                const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
+                const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
+
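+                // scale and accumulate: sumf[col] += a_scale * b_scales[col] * facc[col]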
+                // vector version needs Zvfhmin extension
+                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                const vfloat32m1_t tmp2 = __riscv_vfmul_vv_f32m1(tmp1, b_scales_vec, vl / 4);
+                sumf = __riscv_vfadd_vv_f32m1(sumf, tmp2, vl / 4);
+            }
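+            // one float result per interleaved column for this tile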
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
     {
         float sumf[8];