@@ -860,7 +860,7 @@ void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con
     int ib = 0;
     float sumf = 0;

-#if 0 // defined __AVX2__
+#if defined __AVX2__
     __m256 accum_ps = _mm256_setzero_ps();

     for (; ib + 1 < nb; ib += 2) {
@@ -969,6 +969,134 @@ void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con
     *s = sumf;
 }

+void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP6_E2M3 == 0);
+    static_assert(QK_MXFP6_E2M3 == QK8_0, "QK_MXFP6_E2M3 and QK8_0 must be the same");
+    assert(QK_MXFP6_E2M3 == 32);
+
+    const block_mxfp6_e2m3 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP6_E2M3;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+    __m256 accum_ps = _mm256_setzero_ps();
+
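+    // unrolled main loop: two 32-element blocks per iteration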
+    for (; ib + 1 < nb; ib += 2) {
+        const block_mxfp6_e2m3 * x1 = &x[ib + 0];
+        const block_q8_0 * y1 = &y[ib + 0];
+
+        const block_mxfp6_e2m3 * x2 = &x[ib + 1];
+        const block_q8_0 * y2 = &y[ib + 1];
+
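+        // unpack 32 6-bit values from 24 packed bytes (4 values per 3 bytes,
+        // little-endian bit order) and map each index through the e2m3 lookup
+        // table to get int16 k-values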
+        int16_t k_vals_1[32];
+        {
+            const uint8_t * q3 = x1->qs;
+            for (int j = 0; j < 8; ++j) {
+                const uint8_t b0 = q3[0];
+                const uint8_t b1 = q3[1];
+                const uint8_t b2 = q3[2];
+                k_vals_1[4*j + 0] = kvalues_mxfp6_e2m3[b0 & 0x3F];
+                k_vals_1[4*j + 1] = kvalues_mxfp6_e2m3[(b0 >> 6) | ((b1 & 0x0F) << 2)];
+                k_vals_1[4*j + 2] = kvalues_mxfp6_e2m3[(b1 >> 4) | ((b2 & 0x03) << 4)];
+                k_vals_1[4*j + 3] = kvalues_mxfp6_e2m3[b2 >> 2];
+                q3 += 3;
+            }
+        }
+
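+        // identical unpack for the second block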
+        int16_t k_vals_2[32];
+        {
+            const uint8_t * q3 = x2->qs;
+            for (int j = 0; j < 8; ++j) {
+                const uint8_t b0 = q3[0];
+                const uint8_t b1 = q3[1];
+                const uint8_t b2 = q3[2];
+                k_vals_2[4*j + 0] = kvalues_mxfp6_e2m3[b0 & 0x3F];
+                k_vals_2[4*j + 1] = kvalues_mxfp6_e2m3[(b0 >> 6) | ((b1 & 0x0F) << 2)];
+                k_vals_2[4*j + 2] = kvalues_mxfp6_e2m3[(b1 >> 4) | ((b2 & 0x03) << 4)];
+                k_vals_2[4*j + 3] = kvalues_mxfp6_e2m3[b2 >> 2];
+                q3 += 3;
+            }
+        }
+
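+        // block 1: unaligned loads (the k_vals arrays have no guaranteed
+        // 32-byte alignment), widen the q8 bytes to 16 bit, then madd_epi16
+        // multiplies 16-bit lanes and sums adjacent pairs into 8x s32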
+        const __m256i k_1_lo = _mm256_loadu_si256((const __m256i *)(k_vals_1 + 0));  // k-vals 0-15
+        const __m256i k_1_hi = _mm256_loadu_si256((const __m256i *)(k_vals_1 + 16)); // k-vals 16-31
+
+        const __m256i q8_1_all = _mm256_loadu_si256((const __m256i *)y1->qs);
+
+        const __m256i q8_1_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_1_all, 0)); // q-vals 0-15
+        const __m256i q8_1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_1_all, 1)); // q-vals 16-31
+
+        const __m256i p_1_lo = _mm256_madd_epi16(k_1_lo, q8_1_lo);
+        const __m256i p_1_hi = _mm256_madd_epi16(k_1_hi, q8_1_hi);
+
+        const __m256i p_1_all = _mm256_add_epi32(p_1_lo, p_1_hi); // 8x s32
+
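+        // block 2: same dot-product computation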
+        const __m256i k_2_lo = _mm256_loadu_si256((const __m256i *)(k_vals_2 + 0));
+        const __m256i k_2_hi = _mm256_loadu_si256((const __m256i *)(k_vals_2 + 16));
+        const __m256i q8_2_all = _mm256_loadu_si256((const __m256i *)y2->qs);
+        const __m256i q8_2_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_2_all, 0));
+        const __m256i q8_2_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_2_all, 1));
+        const __m256i p_2_lo = _mm256_madd_epi16(k_2_lo, q8_2_lo);
+        const __m256i p_2_hi = _mm256_madd_epi16(k_2_hi, q8_2_hi);
+        const __m256i p_2_all = _mm256_add_epi32(p_2_lo, p_2_hi); // 8x s32
+
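+        // convert the integer dot products to float and apply the per-block scales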
+        const __m256 p_1_ps = _mm256_cvtepi32_ps(p_1_all);
+        const __m256 p_2_ps = _mm256_cvtepi32_ps(p_2_all);
+
+        // (d = d_y * d_x)
+        const float d1 = GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e);
+        const float d2 = GGML_CPU_FP16_TO_FP32(y2->d) * GGML_E8M0_TO_FP32_HALF(x2->e);
+
+        const __m256 d_1_ps = _mm256_set1_ps(d1);
+        const __m256 d_2_ps = _mm256_set1_ps(d2);
+
+        // Fused Multiply-Add (FMA): accum = (d * p) + accum
+        accum_ps = _mm256_fmadd_ps(d_1_ps, p_1_ps, accum_ps);
+        accum_ps = _mm256_fmadd_ps(d_2_ps, p_2_ps, accum_ps);
+    }
+
+    sumf = hsum_float_8(accum_ps);
+#endif
+
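+    // scalar path: handles the remaining blocks (all blocks when AVX2 is unavailable)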
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi = 0;
+
+        for (int j = 0; j < QK_MXFP6_E2M3/4; ++j) {
+            const uint8_t * q3 = x[ib].qs + 3*j;
+            const int8_t  * q8 = y[ib].qs + 4*j;
+
+            const uint8_t b0 = q3[0];
+            const uint8_t b1 = q3[1];
+            const uint8_t b2 = q3[2];
+
+            const uint8_t v0_idx = b0 & 0x3F;
+            const uint8_t v1_idx = (b0 >> 6) | ((b1 & 0x0F) << 2);
+            const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
+            const uint8_t v3_idx = b2 >> 2;
+
+            sumi += q8[0] * kvalues_mxfp6_e2m3[v0_idx];
+            sumi += q8[1] * kvalues_mxfp6_e2m3[v1_idx];
+            sumi += q8[2] * kvalues_mxfp6_e2m3[v2_idx];
+            sumi += q8[3] * kvalues_mxfp6_e2m3[v3_idx];
+        }
+        sumf += d * sumi;
+    }
+
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;