@@ -844,37 +844,129 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
 void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
-    assert(n % QK_MXFP6_E3M2 == 0);
-    static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP6_E3M2 == 0);
+    static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
+    assert(QK_MXFP6_E3M2 == 32);
+
+    const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
+    const block_q8_0       * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP6_E3M2;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+    __m256 accum_ps = _mm256_setzero_ps();
+
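+    // Main AVX2 loop: two 32-value blocks per iteration; each block yields eight
+    // 32-bit partial sums that are scaled by its block scale and accumulated into accum_ps.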
+    for (; ib + 1 < nb; ib += 2) {
+        const block_mxfp6_e3m2 * x1 = &x[ib + 0];
+        const block_q8_0       * y1 = &y[ib + 0];
+
+        const block_mxfp6_e3m2 * x2 = &x[ib + 1];
+        const block_q8_0       * y2 = &y[ib + 1];
+
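+        // Decode each block's 32 six-bit fields (four values per three packed bytes) and
+        // look up their quantized values into 16-bit staging buffers for the SIMD loads below.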
+        alignas(32) int16_t k_vals_1[32];
+        {
+            const uint8_t * q3 = x1->qs;
+            for (int j = 0; j < 8; ++j) {
+                const uint8_t b0 = q3[0];
+                const uint8_t b1 = q3[1];
+                const uint8_t b2 = q3[2];
+                k_vals_1[4*j + 0] = kvalues_mxfp6_e3m2[b0 & 0x3F];
+                k_vals_1[4*j + 1] = kvalues_mxfp6_e3m2[(b0 >> 6) | ((b1 & 0x0F) << 2)];
+                k_vals_1[4*j + 2] = kvalues_mxfp6_e3m2[(b1 >> 4) | ((b2 & 0x03) << 4)];
+                k_vals_1[4*j + 3] = kvalues_mxfp6_e3m2[b2 >> 2];
+                q3 += 3;
+            }
+        }
+
+        alignas(32) int16_t k_vals_2[32];
+        {
+            const uint8_t * q3 = x2->qs;
+            for (int j = 0; j < 8; ++j) {
+                const uint8_t b0 = q3[0];
+                const uint8_t b1 = q3[1];
+                const uint8_t b2 = q3[2];
+                k_vals_2[4*j + 0] = kvalues_mxfp6_e3m2[b0 & 0x3F];
+                k_vals_2[4*j + 1] = kvalues_mxfp6_e3m2[(b0 >> 6) | ((b1 & 0x0F) << 2)];
+                k_vals_2[4*j + 2] = kvalues_mxfp6_e3m2[(b1 >> 4) | ((b2 & 0x03) << 4)];
+                k_vals_2[4*j + 3] = kvalues_mxfp6_e3m2[b2 >> 2];
+                q3 += 3;
+            }
+        }
+
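+        // Sign-extend the int8 activations to 16 bit, multiply-add adjacent pairs into
+        // 32-bit lanes (_mm256_madd_epi16), and combine the low/high halves of each block.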
+        const __m256i k_1_lo = _mm256_load_si256((const __m256i *)(k_vals_1 + 0));  // k-vals 0-15
+        const __m256i k_1_hi = _mm256_load_si256((const __m256i *)(k_vals_1 + 16)); // k-vals 16-31
+
+        const __m256i q8_1_all = _mm256_loadu_si256((const __m256i *)y1->qs);
+
+        const __m256i q8_1_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_1_all, 0)); // q-vals 0-15
+        const __m256i q8_1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_1_all, 1)); // q-vals 16-31
+
+        const __m256i p_1_lo = _mm256_madd_epi16(k_1_lo, q8_1_lo);
+        const __m256i p_1_hi = _mm256_madd_epi16(k_1_hi, q8_1_hi);
+
+        const __m256i p_1_all = _mm256_add_epi32(p_1_lo, p_1_hi); // 8x s32
+
+        const __m256i k_2_lo = _mm256_load_si256((const __m256i *)(k_vals_2 + 0));
+        const __m256i k_2_hi = _mm256_load_si256((const __m256i *)(k_vals_2 + 16));
+        const __m256i q8_2_all = _mm256_loadu_si256((const __m256i *)y2->qs);
+        const __m256i q8_2_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_2_all, 0));
+        const __m256i q8_2_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_2_all, 1));
+        const __m256i p_2_lo = _mm256_madd_epi16(k_2_lo, q8_2_lo);
+        const __m256i p_2_hi = _mm256_madd_epi16(k_2_hi, q8_2_hi);
+        const __m256i p_2_all = _mm256_add_epi32(p_2_lo, p_2_hi); // 8x s32
+
+        const __m256 p_1_ps = _mm256_cvtepi32_ps(p_1_all);
+        const __m256 p_2_ps = _mm256_cvtepi32_ps(p_2_all);
+
+        // (d = d_y * d_x)
+        const float d1 = GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e);
+        const float d2 = GGML_CPU_FP16_TO_FP32(y2->d) * GGML_E8M0_TO_FP32_HALF(x2->e);
+
+        const __m256 d_1_ps = _mm256_set1_ps(d1);
+        const __m256 d_2_ps = _mm256_set1_ps(d2);
+
+        // Fused Multiply-Add (FMA): accum = (d * p) + accum
+        accum_ps = _mm256_fmadd_ps(d_1_ps, p_1_ps, accum_ps);
+        accum_ps = _mm256_fmadd_ps(d_2_ps, p_2_ps, accum_ps);
+    }
 
-    const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
-    const block_q8_0 * GGML_RESTRICT y = vy;
+    sumf = hsum_float_8(accum_ps);
+#endif
 
-    const int nb = n / QK_MXFP6_E3M2;
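+    // Scalar tail: finishes any block left over from the AVX2 loop; when AVX2 is not
+    // available, ib is still 0 here and this loop handles the entire vector.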
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32_HALF(x[ib].e);
 
-    int ib = 0;
-    float sumf = 0;
+        int sumi = 0;
 
-    for (; ib < nb; ++ib) {
-        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
-        int sumi1 = 0;
-        int sumi2 = 0;
-        int sumi3 = 0;
-        int sumi4 = 0;
-        // Q8_0 (y) * MXFP6 (block_size = 32)
-        for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) {
-            sumi1 += y[ib].qs[j + 0]                 * kvalues_mxfp6_e3m2[ x[ib].qs[3*j] & 0x3f];
-            sumi2 += y[ib].qs[j + 1*QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3*j] >> 6) | ((x[ib].qs[3*j + 1] & 0x0F) << 2)];
-            sumi3 += y[ib].qs[j + 2*QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3*j + 1] >> 4) | ((x[ib].qs[3*j + 2] & 0x03) << 4)];
-            sumi4 += y[ib].qs[j + 3*QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[ x[ib].qs[3*j + 2] >> 2];
+        for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) {
+            const uint8_t * q3 = x[ib].qs + 3*j;
+            const int8_t  * q8 = y[ib].qs + 4*j;
+
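+            // Decode four 6-bit indices from the three packed bytes.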
+            const uint8_t b0 = q3[0];
+            const uint8_t b1 = q3[1];
+            const uint8_t b2 = q3[2];
+
+            const uint8_t v0_idx = b0 & 0x3F;
+            const uint8_t v1_idx = (b0 >> 6) | ((b1 & 0x0F) << 2);
+            const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
+            const uint8_t v3_idx = b2 >> 2;
+
+            sumi += q8[0] * kvalues_mxfp6_e3m2[v0_idx];
+            sumi += q8[1] * kvalues_mxfp6_e3m2[v1_idx];
+            sumi += q8[2] * kvalues_mxfp6_e3m2[v2_idx];
+            sumi += q8[3] * kvalues_mxfp6_e3m2[v3_idx];
+        }
+        sumf += d * sumi;
     }
-        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
-    }
-    *s = sumf;
+
+    *s = sumf;
 }
 
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {