55//
66
77#include " iqk_mmvq.cuh"
8+ #include " iqk_cuda_common.h"
89
910typedef void (*vec_dot_q_cuda_t )(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float *);
1011
@@ -785,77 +786,6 @@ __device__ __forceinline__ void vec_dot_iq6_k_q8_1(
785786 *result += d6 * (__low2float (bq8_1[2 *(i4/2 )+0 ].ds ) * sumi1 * bq6->scales [4 *(i4/2 )+(i4%2 )] + __low2float (bq8_1[2 *(i4/2 )+1 ].ds ) * sumi2 * bq6->scales [4 *(i4/2 )+(i4%2 )+2 ]);
786787}
787788
788- static const __device__ uint32_t iq2k_table[512 ] = {
789- 0xe1e1e1e1 , 0xe1e1e1f3 , 0xe1e1e101 , 0xe1e1e111 , 0xe1e1f3e1 , 0xe1e1f3f3 , 0xe1e1f301 , 0xe1e1f311 ,
790- 0xe1e101e1 , 0xe1e101f3 , 0xe1e10101 , 0xe1e10111 , 0xe1e111e1 , 0xe1e111f3 , 0xe1e11101 , 0xe1e11111 ,
791- 0xe1f3e1e1 , 0xe1f3e1f3 , 0xe1f3e101 , 0xe1f3e111 , 0xe1f3f3e1 , 0xe1f3f3f3 , 0xe1f3f301 , 0xe1f3f311 ,
792- 0xe1f301e1 , 0xe1f301f3 , 0xe1f30101 , 0xe1f30111 , 0xe1f311e1 , 0xe1f311f3 , 0xe1f31101 , 0xe1f31111 ,
793- 0xe101e1e1 , 0xe101e1f3 , 0xe101e101 , 0xe101e111 , 0xe101f3e1 , 0xe101f3f3 , 0xe101f301 , 0xe101f311 ,
794- 0xe10101e1 , 0xe10101f3 , 0xe1010101 , 0xe1010111 , 0xe10111e1 , 0xe10111f3 , 0xe1011101 , 0xe1011111 ,
795- 0xe111e1e1 , 0xe111e1f3 , 0xe111e101 , 0xe111e111 , 0xe111f3e1 , 0xe111f3f3 , 0xe111f301 , 0xe111f311 ,
796- 0xe11101e1 , 0xe11101f3 , 0xe1110101 , 0xe1110111 , 0xe11111e1 , 0xe11111f3 , 0xe1111101 , 0xe1111111 ,
797- 0xf3e1e1e1 , 0xf3e1e1f3 , 0xf3e1e101 , 0xf3e1e111 , 0xf3e1f3e1 , 0xf3e1f3f3 , 0xf3e1f301 , 0xf3e1f311 ,
798- 0xf3e101e1 , 0xf3e101f3 , 0xf3e10101 , 0xf3e10111 , 0xf3e111e1 , 0xf3e111f3 , 0xf3e11101 , 0xf3e11111 ,
799- 0xf3f3e1e1 , 0xf3f3e1f3 , 0xf3f3e101 , 0xf3f3e111 , 0xf3f3f3e1 , 0xf3f3f3f3 , 0xf3f3f301 , 0xf3f3f311 ,
800- 0xf3f301e1 , 0xf3f301f3 , 0xf3f30101 , 0xf3f30111 , 0xf3f311e1 , 0xf3f311f3 , 0xf3f31101 , 0xf3f31111 ,
801- 0xf301e1e1 , 0xf301e1f3 , 0xf301e101 , 0xf301e111 , 0xf301f3e1 , 0xf301f3f3 , 0xf301f301 , 0xf301f311 ,
802- 0xf30101e1 , 0xf30101f3 , 0xf3010101 , 0xf3010111 , 0xf30111e1 , 0xf30111f3 , 0xf3011101 , 0xf3011111 ,
803- 0xf311e1e1 , 0xf311e1f3 , 0xf311e101 , 0xf311e111 , 0xf311f3e1 , 0xf311f3f3 , 0xf311f301 , 0xf311f311 ,
804- 0xf31101e1 , 0xf31101f3 , 0xf3110101 , 0xf3110111 , 0xf31111e1 , 0xf31111f3 , 0xf3111101 , 0xf3111111 ,
805- 0x01e1e1e1 , 0x01e1e1f3 , 0x01e1e101 , 0x01e1e111 , 0x01e1f3e1 , 0x01e1f3f3 , 0x01e1f301 , 0x01e1f311 ,
806- 0x01e101e1 , 0x01e101f3 , 0x01e10101 , 0x01e10111 , 0x01e111e1 , 0x01e111f3 , 0x01e11101 , 0x01e11111 ,
807- 0x01f3e1e1 , 0x01f3e1f3 , 0x01f3e101 , 0x01f3e111 , 0x01f3f3e1 , 0x01f3f3f3 , 0x01f3f301 , 0x01f3f311 ,
808- 0x01f301e1 , 0x01f301f3 , 0x01f30101 , 0x01f30111 , 0x01f311e1 , 0x01f311f3 , 0x01f31101 , 0x01f31111 ,
809- 0x0101e1e1 , 0x0101e1f3 , 0x0101e101 , 0x0101e111 , 0x0101f3e1 , 0x0101f3f3 , 0x0101f301 , 0x0101f311 ,
810- 0x010101e1 , 0x010101f3 , 0x01010101 , 0x01010111 , 0x010111e1 , 0x010111f3 , 0x01011101 , 0x01011111 ,
811- 0x0111e1e1 , 0x0111e1f3 , 0x0111e101 , 0x0111e111 , 0x0111f3e1 , 0x0111f3f3 , 0x0111f301 , 0x0111f311 ,
812- 0x011101e1 , 0x011101f3 , 0x01110101 , 0x01110111 , 0x011111e1 , 0x011111f3 , 0x01111101 , 0x01111111 ,
813- 0x11e1e1e1 , 0x11e1e1f3 , 0x11e1e101 , 0x11e1e111 , 0x11e1f3e1 , 0x11e1f3f3 , 0x11e1f301 , 0x11e1f311 ,
814- 0x11e101e1 , 0x11e101f3 , 0x11e10101 , 0x11e10111 , 0x11e111e1 , 0x11e111f3 , 0x11e11101 , 0x11e11111 ,
815- 0x11f3e1e1 , 0x11f3e1f3 , 0x11f3e101 , 0x11f3e111 , 0x11f3f3e1 , 0x11f3f3f3 , 0x11f3f301 , 0x11f3f311 ,
816- 0x11f301e1 , 0x11f301f3 , 0x11f30101 , 0x11f30111 , 0x11f311e1 , 0x11f311f3 , 0x11f31101 , 0x11f31111 ,
817- 0x1101e1e1 , 0x1101e1f3 , 0x1101e101 , 0x1101e111 , 0x1101f3e1 , 0x1101f3f3 , 0x1101f301 , 0x1101f311 ,
818- 0x110101e1 , 0x110101f3 , 0x11010101 , 0x11010111 , 0x110111e1 , 0x110111f3 , 0x11011101 , 0x11011111 ,
819- 0x1111e1e1 , 0x1111e1f3 , 0x1111e101 , 0x1111e111 , 0x1111f3e1 , 0x1111f3f3 , 0x1111f301 , 0x1111f311 ,
820- 0x111101e1 , 0x111101f3 , 0x11110101 , 0x11110111 , 0x111111e1 , 0x111111f3 , 0x11111101 , 0x11111111 ,
821- 0xe6e6e6e6 , 0xe6e6e6f8 , 0xe6e6e606 , 0xe6e6e616 , 0xe6e6f8e6 , 0xe6e6f8f8 , 0xe6e6f806 , 0xe6e6f816 ,
822- 0xe6e606e6 , 0xe6e606f8 , 0xe6e60606 , 0xe6e60616 , 0xe6e616e6 , 0xe6e616f8 , 0xe6e61606 , 0xe6e61616 ,
823- 0xe6f8e6e6 , 0xe6f8e6f8 , 0xe6f8e606 , 0xe6f8e616 , 0xe6f8f8e6 , 0xe6f8f8f8 , 0xe6f8f806 , 0xe6f8f816 ,
824- 0xe6f806e6 , 0xe6f806f8 , 0xe6f80606 , 0xe6f80616 , 0xe6f816e6 , 0xe6f816f8 , 0xe6f81606 , 0xe6f81616 ,
825- 0xe606e6e6 , 0xe606e6f8 , 0xe606e606 , 0xe606e616 , 0xe606f8e6 , 0xe606f8f8 , 0xe606f806 , 0xe606f816 ,
826- 0xe60606e6 , 0xe60606f8 , 0xe6060606 , 0xe6060616 , 0xe60616e6 , 0xe60616f8 , 0xe6061606 , 0xe6061616 ,
827- 0xe616e6e6 , 0xe616e6f8 , 0xe616e606 , 0xe616e616 , 0xe616f8e6 , 0xe616f8f8 , 0xe616f806 , 0xe616f816 ,
828- 0xe61606e6 , 0xe61606f8 , 0xe6160606 , 0xe6160616 , 0xe61616e6 , 0xe61616f8 , 0xe6161606 , 0xe6161616 ,
829- 0xf8e6e6e6 , 0xf8e6e6f8 , 0xf8e6e606 , 0xf8e6e616 , 0xf8e6f8e6 , 0xf8e6f8f8 , 0xf8e6f806 , 0xf8e6f816 ,
830- 0xf8e606e6 , 0xf8e606f8 , 0xf8e60606 , 0xf8e60616 , 0xf8e616e6 , 0xf8e616f8 , 0xf8e61606 , 0xf8e61616 ,
831- 0xf8f8e6e6 , 0xf8f8e6f8 , 0xf8f8e606 , 0xf8f8e616 , 0xf8f8f8e6 , 0xf8f8f8f8 , 0xf8f8f806 , 0xf8f8f816 ,
832- 0xf8f806e6 , 0xf8f806f8 , 0xf8f80606 , 0xf8f80616 , 0xf8f816e6 , 0xf8f816f8 , 0xf8f81606 , 0xf8f81616 ,
833- 0xf806e6e6 , 0xf806e6f8 , 0xf806e606 , 0xf806e616 , 0xf806f8e6 , 0xf806f8f8 , 0xf806f806 , 0xf806f816 ,
834- 0xf80606e6 , 0xf80606f8 , 0xf8060606 , 0xf8060616 , 0xf80616e6 , 0xf80616f8 , 0xf8061606 , 0xf8061616 ,
835- 0xf816e6e6 , 0xf816e6f8 , 0xf816e606 , 0xf816e616 , 0xf816f8e6 , 0xf816f8f8 , 0xf816f806 , 0xf816f816 ,
836- 0xf81606e6 , 0xf81606f8 , 0xf8160606 , 0xf8160616 , 0xf81616e6 , 0xf81616f8 , 0xf8161606 , 0xf8161616 ,
837- 0x06e6e6e6 , 0x06e6e6f8 , 0x06e6e606 , 0x06e6e616 , 0x06e6f8e6 , 0x06e6f8f8 , 0x06e6f806 , 0x06e6f816 ,
838- 0x06e606e6 , 0x06e606f8 , 0x06e60606 , 0x06e60616 , 0x06e616e6 , 0x06e616f8 , 0x06e61606 , 0x06e61616 ,
839- 0x06f8e6e6 , 0x06f8e6f8 , 0x06f8e606 , 0x06f8e616 , 0x06f8f8e6 , 0x06f8f8f8 , 0x06f8f806 , 0x06f8f816 ,
840- 0x06f806e6 , 0x06f806f8 , 0x06f80606 , 0x06f80616 , 0x06f816e6 , 0x06f816f8 , 0x06f81606 , 0x06f81616 ,
841- 0x0606e6e6 , 0x0606e6f8 , 0x0606e606 , 0x0606e616 , 0x0606f8e6 , 0x0606f8f8 , 0x0606f806 , 0x0606f816 ,
842- 0x060606e6 , 0x060606f8 , 0x06060606 , 0x06060616 , 0x060616e6 , 0x060616f8 , 0x06061606 , 0x06061616 ,
843- 0x0616e6e6 , 0x0616e6f8 , 0x0616e606 , 0x0616e616 , 0x0616f8e6 , 0x0616f8f8 , 0x0616f806 , 0x0616f816 ,
844- 0x061606e6 , 0x061606f8 , 0x06160606 , 0x06160616 , 0x061616e6 , 0x061616f8 , 0x06161606 , 0x06161616 ,
845- 0x16e6e6e6 , 0x16e6e6f8 , 0x16e6e606 , 0x16e6e616 , 0x16e6f8e6 , 0x16e6f8f8 , 0x16e6f806 , 0x16e6f816 ,
846- 0x16e606e6 , 0x16e606f8 , 0x16e60606 , 0x16e60616 , 0x16e616e6 , 0x16e616f8 , 0x16e61606 , 0x16e61616 ,
847- 0x16f8e6e6 , 0x16f8e6f8 , 0x16f8e606 , 0x16f8e616 , 0x16f8f8e6 , 0x16f8f8f8 , 0x16f8f806 , 0x16f8f816 ,
848- 0x16f806e6 , 0x16f806f8 , 0x16f80606 , 0x16f80616 , 0x16f816e6 , 0x16f816f8 , 0x16f81606 , 0x16f81616 ,
849- 0x1606e6e6 , 0x1606e6f8 , 0x1606e606 , 0x1606e616 , 0x1606f8e6 , 0x1606f8f8 , 0x1606f806 , 0x1606f816 ,
850- 0x160606e6 , 0x160606f8 , 0x16060606 , 0x16060616 , 0x160616e6 , 0x160616f8 , 0x16061606 , 0x16061616 ,
851- 0x1616e6e6 , 0x1616e6f8 , 0x1616e606 , 0x1616e616 , 0x1616f8e6 , 0x1616f8f8 , 0x1616f806 , 0x1616f816 ,
852- 0x161606e6 , 0x161606f8 , 0x16160606 , 0x16160616 , 0x161616e6 , 0x161616f8 , 0x16161606 , 0x16161616 ,
853- };
854-
855- __device__ __forceinline__ int int_from_table_4 (const uint8_t * a8, const int * values) {
856- return values[a8[0 ] | (a8[1 ] << 2 ) | (a8[2 ] << 4 ) | (a8[3 ] << 6 )];
857- }
858-
859789#define VDR_IQ2_K_Q8_1_MMVQ 4
860790#define VDR_IQ2_K_Q8_1_MMQ 4
861791
@@ -881,7 +811,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
881811 uint32_t val1 = q2[0 ], val2 = q2[1 ];
882812
883813 uint32_t aux32[2 ];
884- const uint8_t * a8 = (const uint8_t *)&aux32;
885814 int v1, v2;
886815
887816 // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
@@ -892,23 +821,23 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
892821 const int8_t * s8 = (const int8_t *)&s32;
893822
894823 aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
895- v1 = int_from_table_4 (a8 + 0 , values);
896- v2 = int_from_table_4 (a8 + 4 , values);
824+ v1 = int_from_table_4 (aux32[ 0 ] , values);
825+ v2 = int_from_table_4 (aux32[ 1 ] , values);
897826 int sumi1 = ggml_cuda_dp4a (v2, q8_1[1 ], ggml_cuda_dp4a (v1, q8_1[0 ], 0 )) * s8[0 ];
898827
899828 aux32[0 ] = ((val1 >> 2 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 2 ) & 0x03030303 ); values = all_values + ((extra & 0x04 ) << 6 );
900- v1 = int_from_table_4 (a8 + 0 , values);
901- v2 = int_from_table_4 (a8 + 4 , values);
829+ v1 = int_from_table_4 (aux32[ 0 ] , values);
830+ v2 = int_from_table_4 (aux32[ 1 ] , values);
902831 int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[1 ];
903832
904833 aux32[0 ] = ((val1 >> 4 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 4 ) & 0x03030303 ); values = all_values + ((extra & 0x10 ) << 4 );
905- v1 = int_from_table_4 (a8 + 0 , values);
906- v2 = int_from_table_4 (a8 + 4 , values);
834+ v1 = int_from_table_4 (aux32[ 0 ] , values);
835+ v2 = int_from_table_4 (aux32[ 1 ] , values);
907836 int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[2 ];
908837
909838 aux32[0 ] = ((val1 >> 6 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 6 ) & 0x03030303 ); values = all_values + ((extra & 0x40 ) << 2 );
910- v1 = int_from_table_4 (a8 + 0 , values);
911- v2 = int_from_table_4 (a8 + 4 , values);
839+ v1 = int_from_table_4 (aux32[ 0 ] , values);
840+ v2 = int_from_table_4 (aux32[ 1 ] , values);
912841 int sumi4 = ggml_cuda_dp4a (v2, q8_4[1 ], ggml_cuda_dp4a (v1, q8_4[0 ], 0 )) * s8[3 ];
913842
914843 *result += __half2float (bq2->d ) * (__low2float (bq8_1[4 *(i4/4 )+0 ].ds ) * sumi1
@@ -941,7 +870,6 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
941870 uint32_t val1 = q2[0 ] | (q2[1 ] << 16 ), val2 = q2[2 ] | (q2[3 ] << 16 );
942871
943872 uint32_t aux32[2 ];
944- const uint8_t * a8 = (const uint8_t *)&aux32;
945873 int v1, v2;
946874
947875 int32_t scales32;
@@ -954,23 +882,23 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
954882 s8[3 ] += ((extra >> 7 ) & 0x10 );
955883
956884 aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
957- v1 = int_from_table_4 (a8 + 0 , values);
958- v2 = int_from_table_4 (a8 + 4 , values);
885+ v1 = int_from_table_4 (aux32[ 0 ] , values);
886+ v2 = int_from_table_4 (aux32[ 1 ] , values);
959887 int sumi1 = ggml_cuda_dp4a (v2, q8_1[1 ], ggml_cuda_dp4a (v1, q8_1[0 ], 0 )) * s8[0 ];
960888
961889 aux32[0 ] = ((val1 >> 2 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 2 ) & 0x03030303 ); values = all_values + ((extra & 0x02 ) << 7 );
962- v1 = int_from_table_4 (a8 + 0 , values);
963- v2 = int_from_table_4 (a8 + 4 , values);
890+ v1 = int_from_table_4 (aux32[ 0 ] , values);
891+ v2 = int_from_table_4 (aux32[ 1 ] , values);
964892 int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[2 ];
965893
966894 aux32[0 ] = ((val1 >> 4 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 4 ) & 0x03030303 ); values = all_values + ((extra & 0x04 ) << 6 );
967- v1 = int_from_table_4 (a8 + 0 , values);
968- v2 = int_from_table_4 (a8 + 4 , values);
895+ v1 = int_from_table_4 (aux32[ 0 ] , values);
896+ v2 = int_from_table_4 (aux32[ 1 ] , values);
969897 int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[1 ];
970898
971899 aux32[0 ] = ((val1 >> 6 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 6 ) & 0x03030303 ); values = all_values + ((extra & 0x08 ) << 5 );
972- v1 = int_from_table_4 (a8 + 0 , values);
973- v2 = int_from_table_4 (a8 + 4 , values);
900+ v1 = int_from_table_4 (aux32[ 0 ] , values);
901+ v2 = int_from_table_4 (aux32[ 1 ] , values);
974902 int sumi4 = ggml_cuda_dp4a (v2, q8_4[1 ], ggml_cuda_dp4a (v1, q8_4[0 ], 0 )) * s8[3 ];
975903
976904 *result += scale * (__low2float (bq8_1[4 *(i4/4 )+0 ].ds ) * sumi1
@@ -1000,20 +928,19 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
1000928 int2 val1;
1001929 const int * q2 = (const int *)bq2->qs + 8 *ib32 + 4 *is;
1002930 int aux32[2 ];
1003- const uint8_t * aux8 = (const uint8_t *)aux32;
1004931#pragma unroll
1005932 for (int i = 0 ; i < 4 ; ++i) {
1006933 auto values1 = all_values + (((bq2->extra [i+4 *is] >> ib32) & 1 ) << 8 );
1007934 int sumi1 = 0 ;
1008935 aux32[0 ] = ((q2[i] >> 0 ) & 0x03030303 );
1009936 aux32[1 ] = ((q2[i] >> 2 ) & 0x03030303 );
1010- val1.x = int_from_table_4 (aux8+ 0 , values1);
1011- val1.y = int_from_table_4 (aux8+ 4 , values1);
937+ val1.x = int_from_table_4 (aux32[ 0 ] , values1);
938+ val1.y = int_from_table_4 (aux32[ 1 ] , values1);
1012939 sumi1 = ggml_cuda_dp4a (val1.x , q8[0 ], ggml_cuda_dp4a (val1.y , q8[1 ], sumi1));
1013940 aux32[0 ] = ((q2[i] >> 4 ) & 0x03030303 );
1014941 aux32[1 ] = ((q2[i] >> 6 ) & 0x03030303 );
1015- val1.x = int_from_table_4 (aux8+ 0 , values1);
1016- val1.y = int_from_table_4 (aux8+ 4 , values1);
942+ val1.x = int_from_table_4 (aux32[ 0 ] , values1);
943+ val1.y = int_from_table_4 (aux32[ 1 ] , values1);
1017944 sumi1 = ggml_cuda_dp4a (val1.x , q8[2 ], ggml_cuda_dp4a (val1.y , q8[3 ], sumi1));
1018945 const float d = __half2float (bq2->d [i]) * d8;
1019946 result[i] += d * sumi1 * s8[i];
@@ -1114,7 +1041,6 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
11141041 const int ib128 = iqs/4 ; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
11151042 // Each thread processes 8 quants in each of the 4 32-blocks
11161043 const int il8 = iqs%4 ; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
1117- const int shift = 4 *(il8/2 );
11181044
11191045 const uint16_t * ql = (const uint16_t *)bq3->qs + 16 *ib128 + 4 *il8;
11201046 const uint16_t * qh = (const uint16_t *)bq3->qh + 4 *il8;
0 commit comments