@@ -849,6 +849,34 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
849849 const uint32_t * q2 = (const uint32_t *)bq2->qs + 8 *(i4/4 ) + 2 *(i4%4 );
850850 const uint16_t extra = bq2->extra >> (8 *(i4/4 ) + (i4%4 )/2 );
851851
852+ const uint32_t * scales = (const uint32_t *)bq2->scales ;
853+ uint32_t s32 = __vsub4 ((scales[i4/4 ] >> 4 *(((i4%4 )/2 )%2 )) & 0x0f0f0f0f , 0x08080808 );
854+ const int8_t * s8 = (const int8_t *)&s32;
855+
856+ // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
857+ // -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)
858+
859+ #ifdef __CUDA_ARCH__
860+ uint32_t extra32 = uint32_t (extra & 0xff ) * 0x01010101 ;
861+ uint32_t extra32_1 = (extra32 << 2 ) & 0x44444444 ;
862+ uint32_t extra32_2 = (extra32 << 0 ) & 0x44444444 ;
863+
864+ uint32_t val1, val2;
865+
866+ val1 = ((q2[0 ] >> 0 ) & 0x33333333 ) | extra32_1; val2 = ((q2[1 ] >> 0 ) & 0x33333333 ) | extra32_1;
867+ int2 v1 = get_int_from_table_8 (val1, iq2nl_values);
868+ int2 v2 = get_int_from_table_8 (val2, iq2nl_values);
869+ int sumi1 = ggml_cuda_dp4a (v2.x , q8_1[1 ], ggml_cuda_dp4a (v1.x , q8_1[0 ], 0 )) * s8[0 ];
870+ int sumi3 = ggml_cuda_dp4a (v2.y , q8_3[1 ], ggml_cuda_dp4a (v1.y , q8_3[0 ], 0 )) * s8[2 ];
871+
872+ val1 = ((q2[0 ] >> 2 ) & 0x33333333 ) | extra32_2; val2 = ((q2[1 ] >> 2 ) & 0x33333333 ) | extra32_2;
873+ v1 = get_int_from_table_8 (val1, iq2nl_values);
874+ v2 = get_int_from_table_8 (val2, iq2nl_values);
875+ int sumi2 = ggml_cuda_dp4a (v2.x , q8_2[1 ], ggml_cuda_dp4a (v1.x , q8_2[0 ], 0 )) * s8[1 ];
876+ int sumi4 = ggml_cuda_dp4a (v2.y , q8_4[1 ], ggml_cuda_dp4a (v1.y , q8_4[0 ], 0 )) * s8[3 ];
877+
878+ #else
879+
852880 const int * all_values = (const int *)iq2k_table;
853881 const int * values;
854882
@@ -857,13 +885,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
857885 uint32_t aux32[2 ];
858886 int v1, v2;
859887
860- // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
861- // -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)
862-
863- const uint32_t * scales = (const uint32_t *)bq2->scales ;
864- uint32_t s32 = __vsub4 ((scales[i4/4 ] >> 4 *(((i4%4 )/2 )%2 )) & 0x0f0f0f0f , 0x08080808 );
865- const int8_t * s8 = (const int8_t *)&s32;
866-
867888 aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
868889 v1 = int_from_table_4 (aux32[0 ], values);
869890 v2 = int_from_table_4 (aux32[1 ], values);
@@ -883,6 +904,7 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
883904 v1 = int_from_table_4 (aux32[0 ], values);
884905 v2 = int_from_table_4 (aux32[1 ], values);
885906 int sumi4 = ggml_cuda_dp4a (v2, q8_4[1 ], ggml_cuda_dp4a (v1, q8_4[0 ], 0 )) * s8[3 ];
907+ #endif
886908
887909 *result += __half2float (bq2->d ) * (__low2float (bq8_1[4 *(i4/4 )+0 ].ds ) * sumi1
888910 + __low2float (bq8_1[4 *(i4/4 )+1 ].ds ) * sumi2
@@ -908,14 +930,8 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
908930 const uint16_t * q2 = (const uint16_t *)bq2->qs + 16 *(i4/4 ) + 4 *(i4%4 );
909931 const uint16_t extra = bq2->extra >> 4 *(i4/4 );
910932
911- const int * all_values = (const int *)iq2k_table;
912- const int * values;
913-
914933 uint32_t val1 = q2[0 ] | (q2[1 ] << 16 ), val2 = q2[2 ] | (q2[3 ] << 16 );
915934
916- uint32_t aux32[2 ];
917- int v1, v2;
918-
919935 int32_t scales32;
920936 const uint16_t * scales16 = (const uint16_t *)bq2->scales ;
921937 scales32 = __vsub4 ((scales16[i4/4 ] | (scales16[i4/4 ] << 12 )) & 0x0f0f0f0f , 0x10101010 );
@@ -925,6 +941,35 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
925941 s8[2 ] += ((extra >> 5 ) & 0x10 );
926942 s8[3 ] += ((extra >> 7 ) & 0x10 );
927943
944+ #ifdef __CUDA_ARCH__
945+
946+ uint32_t extra32 = uint32_t (extra & 0xf ) * 0x01010101 ;
947+
948+ uint32_t this_extra = ((extra32 << 2 ) & 0x04040404 ) | ((extra32 << 4 ) & 0x40404040 );
949+ uint32_t idx1 = ((val1 >> 0 ) & 0x33333333 ) | this_extra;
950+ uint32_t idx2 = ((val2 >> 0 ) & 0x33333333 ) | this_extra;
951+ int2 v1 = get_int_from_table_8 (idx1, iq2nl_values);
952+ int2 v2 = get_int_from_table_8 (idx2, iq2nl_values);
953+
954+ int sumi1 = ggml_cuda_dp4a (v2.x , q8_1[1 ], ggml_cuda_dp4a (v1.x , q8_1[0 ], 0 )) * s8[0 ];
955+ int sumi3 = ggml_cuda_dp4a (v2.y , q8_3[1 ], ggml_cuda_dp4a (v1.y , q8_3[0 ], 0 )) * s8[1 ];
956+
957+ this_extra = ((extra32 << 1 ) & 0x04040404 ) | ((extra32 << 3 ) & 0x40404040 );
958+ idx1 = ((val1 >> 2 ) & 0x33333333 ) | this_extra;
959+ idx2 = ((val2 >> 2 ) & 0x33333333 ) | this_extra;
960+ v1 = get_int_from_table_8 (idx1, iq2nl_values);
961+ v2 = get_int_from_table_8 (idx2, iq2nl_values);
962+
963+ int sumi2 = ggml_cuda_dp4a (v2.x , q8_2[1 ], ggml_cuda_dp4a (v1.x , q8_2[0 ], 0 )) * s8[2 ];
964+ int sumi4 = ggml_cuda_dp4a (v2.y , q8_4[1 ], ggml_cuda_dp4a (v1.y , q8_4[0 ], 0 )) * s8[3 ];
965+
966+ #else
967+
968+ uint32_t aux32[2 ];
969+ int v1, v2;
970+ const int * all_values = (const int *)iq2k_table;
971+ const int * values;
972+
928973 aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
929974 v1 = int_from_table_4 (aux32[0 ], values);
930975 v2 = int_from_table_4 (aux32[1 ], values);
@@ -944,6 +989,7 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
944989 v1 = int_from_table_4 (aux32[0 ], values);
945990 v2 = int_from_table_4 (aux32[1 ], values);
946991 int sumi4 = ggml_cuda_dp4a (v2, q8_4[1 ], ggml_cuda_dp4a (v1, q8_4[0 ], 0 )) * s8[3 ];
992+ #endif
947993
948994 *result += scale * (__low2float (bq8_1[4 *(i4/4 )+0 ].ds ) * sumi1
949995 + __low2float (bq8_1[4 *(i4/4 )+1 ].ds ) * sumi2
@@ -965,12 +1011,31 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
9651011 int is = ib16%2 ;
9661012 const int * scales_l = (const int *)bq2->scales ;
9671013
968- const int * all_values = (const int *)iq2k_table;
969-
9701014 int scales = __vsub4 (((scales_l[2 *(ib32%4 )+is] >> 4 *(ib32/4 )) & 0x0f0f0f0f ), 0x08080808 );
9711015 const int8_t * s8 = (const int8_t *)&scales;
972- int2 val1;
1016+
9731017 const int * q2 = (const int *)bq2->qs + 8 *ib32 + 4 *is;
1018+
1019+ #ifdef __CUDA_ARCH__
1020+
1021+ #pragma unroll
1022+ for (int i = 0 ; i < 4 ; ++i) {
1023+ uint32_t extra32 = uint32_t ((bq2->extra [i+4 *is] >> ib32) & 1 ) * 0x04040404 ;
1024+ extra32 |= (extra32 << 4 );
1025+ uint32_t val1 = ((q2[i] >> 0 ) & 0x33333333 ) | extra32;
1026+ uint32_t val2 = ((q2[i] >> 2 ) & 0x33333333 ) | extra32;
1027+ int2 v1 = get_int_from_table_8 (val1, iq2nl_values);
1028+ int2 v2 = get_int_from_table_8 (val2, iq2nl_values);
1029+ int sumi = 0 ;
1030+ sumi = ggml_cuda_dp4a (v1.x , q8[0 ], ggml_cuda_dp4a (v2.x , q8[1 ], sumi));
1031+ sumi = ggml_cuda_dp4a (v1.y , q8[2 ], ggml_cuda_dp4a (v2.y , q8[3 ], sumi));
1032+ const float d = __half2float (bq2->d [i]) * d8;
1033+ result[i] += d * sumi * s8[i];
1034+ }
1035+
1036+ #else
1037+ const int * all_values = (const int *)iq2k_table;
1038+ int2 val1;
9741039 int aux32[2 ];
9751040#pragma unroll
9761041 for (int i = 0 ; i < 4 ; ++i) {
@@ -989,6 +1054,7 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
9891054 const float d = __half2float (bq2->d [i]) * d8;
9901055 result[i] += d * sumi1 * s8[i];
9911056 }
1057+ #endif
9921058}
9931059
9941060#define VDR_IQ3_K_Q8_1_MMVQ 4
0 commit comments