@@ -776,28 +776,21 @@ __device__ __forceinline__ void vec_dot_iq3_k_r4_q8_1(
776776 // scales[1] = __vcmpeq4((scales_h[is] >> ib32) & 0x01010101, 0x01010101);
777777 // scales[0] = __vsub4(scales[0] ^ scales[1], scales[1]);
778778 const int8_t * s8 = (const int8_t *)scales;
779- int2 val1;
780- const int * q2 = (const int *)bq3->qs + 8 *ib32 + 4 *is;
781- const int * qh = (const int *)bq3->qh + 4 *ib32;
782- int aux32[2 ];
783- const uint8_t * aux8 = (const uint8_t *)aux32;
779+ const uint32_t * q2 = (const uint32_t *)bq3->qs + 8 *ib32 + 4 *is;
780+ const uint32_t * qh = (const uint32_t *)bq3->qh + 4 *ib32;
784781 for (int i = 0 ; i < 4 ; ++i) {
785- auto values1 = iq3nl_values + (((bq3->extra [i+4 *is] >> ib32) & 1 ) << 3 );
782+ uint32_t extra32 = uint32_t ((bq3->extra [i+4 *is] >> ib32) & 1 ) * 0x88888888 ;
783+
786784 int sumi1 = 0 ;
787- int h = qh[i] >> 4 *is;
788- aux32[0 ] = ((q2[i] >> 0 ) & 0x03030303 ) | ((h << 2 ) & 0x04040404 );
789- aux32[1 ] = ((q2[i] >> 2 ) & 0x03030303 ) | ((h << 1 ) & 0x04040404 );
790- val1.x = int_from_table (aux8+0 , (const uint8_t *)values1);
791- val1.y = int_from_table (aux8+4 , (const uint8_t *)values1);
792- sumi1 = ggml_cuda_dp4a (val1.x , q8[0 ], ggml_cuda_dp4a (val1.y , q8[1 ], sumi1));
793- aux32[0 ] = ((q2[i] >> 4 ) & 0x03030303 ) | ((h >> 0 ) & 0x04040404 );
794- aux32[1 ] = ((q2[i] >> 6 ) & 0x03030303 ) | ((h >> 1 ) & 0x04040404 );
795- val1.x = int_from_table (aux8+0 , (const uint8_t *)values1);
796- val1.y = int_from_table (aux8+4 , (const uint8_t *)values1);
797- sumi1 = ggml_cuda_dp4a (val1.x , q8[2 ], ggml_cuda_dp4a (val1.y , q8[3 ], sumi1));
785+ uint32_t h = qh[i] >> 4 *is;
786+ uint32_t val1 = ((q2[i] >> 0 ) & 0x33333333 ) | extra32 | ((h << 2 ) & 0x04040404 ) | ((h << 4 ) & 0x40404040 );
787+ uint32_t val2 = ((q2[i] >> 2 ) & 0x33333333 ) | extra32 | ((h << 1 ) & 0x04040404 ) | ((h << 3 ) & 0x40404040 );
788+ int2 v1 = get_int_from_table_16 (val1, iq3nl_values);
789+ int2 v2 = get_int_from_table_16 (val2, iq3nl_values);
790+ sumi1 = ggml_cuda_dp4a (v1.x , q8[0 ], ggml_cuda_dp4a (v2.x , q8[1 ], sumi1));
791+ sumi1 = ggml_cuda_dp4a (v1.y , q8[2 ], ggml_cuda_dp4a (v2.y , q8[3 ], sumi1));
798792 const float d = __half2float (bq3->d [i]) * d8;
799793 result[i] += d * sumi1 * s8[i] * (s8[i+4 ] ? -1 : 1 );
800- // result[i] += d * sumi1 * s8[i];
801794 }
802795}
803796
@@ -1021,41 +1014,32 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1(
10211014 const uint16_t sh = bq3->scales_h >> (8 *ib128 + il8/2 );
10221015
10231016 const uint8_t extra = bq3->extra >> (8 *ib128 + il8/2 );
1024- const uint16_t * values1 = iq3k_table + ((extra << 6 ) & 0x40 );
1025- const uint16_t * values2 = iq3k_table + ((extra << 5 ) & 0x40 );
1026- const uint16_t * values3 = iq3k_table + ((extra << 4 ) & 0x40 );
1027- const uint16_t * values4 = iq3k_table + ((extra << 3 ) & 0x40 );
1017+ uint32_t extra32 = uint32_t (extra) * 0x01010101 ;
1018+ uint32_t extra32_1 = ((extra32 << 3 ) & 0x08080808 ) | ((extra32 << 5 ) & 0x80808080 );
1019+ uint32_t extra32_2 = ((extra32 << 2 ) & 0x08080808 ) | ((extra32 << 4 ) & 0x80808080 );
10281020
10291021 const int * q8;
10301022 int sumi[4 ] = {0 , 0 , 0 , 0 };
1031- int v;
10321023 for (int i = 0 ; i < 2 ; ++i) {
10331024 uint32_t vl = ql[2 *i+0 ] | (ql[2 *i+1 ] << 16 );
1034- uint32_t vh = ((qh[2 *i+0 ] | (qh[2 *i+1 ] << 16 )) << hshift) >> 2 ;
1025+ uint32_t vh = ((qh[2 *i+0 ] | (qh[2 *i+1 ] << 16 )) << hshift);
1026+
1027+ uint32_t val1 = ((vl >> 0 ) & 0x33333333 ) | extra32_1 | ((vh >> 2 ) & 0x04040404 ) | ((vh >> 0 ) & 0x40404040 );
1028+ uint32_t val2 = ((vl >> 2 ) & 0x33333333 ) | extra32_2 | ((vh >> 3 ) & 0x04040404 ) | ((vh >> 1 ) & 0x40404040 );
1029+ int2 v1 = get_int_from_table_16 (val1, iq3nl_values);
1030+ int2 v2 = get_int_from_table_16 (val2, iq3nl_values);
10351031
10361032 q8 = (const int *)bq8_1[4 *ib128+0 ].qs + 2 *il8;
1037- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1038- v = int_from_table_2 (aux8, values1);
1039- sumi[0 ] = ggml_cuda_dp4a (v, q8[i], sumi[0 ]);
1040- vl >>= 2 ; vh >>= 1 ;
1033+ sumi[0 ] = ggml_cuda_dp4a (v1.x , q8[i], sumi[0 ]);
10411034
10421035 q8 += sizeof (block_q8_1)/4 ;
1043- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1044- v = int_from_table_2 (aux8, values2);
1045- sumi[1 ] = ggml_cuda_dp4a (v, q8[i], sumi[1 ]);
1046- vl >>= 2 ; vh >>= 1 ;
1036+ sumi[1 ] = ggml_cuda_dp4a (v2.x , q8[i], sumi[1 ]);
10471037
10481038 q8 += sizeof (block_q8_1)/4 ;
1049- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1050- v = int_from_table_2 (aux8, values3);
1051- sumi[2 ] = ggml_cuda_dp4a (v, q8[i], sumi[2 ]);
1052- vl >>= 2 ; vh >>= 1 ;
1039+ sumi[2 ] = ggml_cuda_dp4a (v1.y , q8[i], sumi[2 ]);
10531040
10541041 q8 += sizeof (block_q8_1)/4 ;
1055- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1056- v = int_from_table_2 (aux8, values4);
1057- sumi[3 ] = ggml_cuda_dp4a (v, q8[i], sumi[3 ]);
1058-
1042+ sumi[3 ] = ggml_cuda_dp4a (v2.y , q8[i], sumi[3 ]);
10591043 }
10601044 const float d = __half2float (bq3->d );
10611045 const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2 *ib128;
@@ -1127,50 +1111,37 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
11271111 const uint16_t * ql = (const uint16_t *)bq3->qs + 16 *ib128 + 4 *il8;
11281112 const uint16_t * qh = (const uint16_t *)bq3->qh + 4 *il8;
11291113
1130- int32_t aux32;
1131- const uint8_t * aux8 = (const uint8_t *)&aux32;
1132-
11331114 uint16_t extra = bq3->extra >> 4 *ib128;
1134- uint16_t extra_v = extra >> 8 ;
1115+ uint32_t extra_v = uint32_t ( extra >> 8 ) * 0x01010101 ;
11351116
1136- const uint16_t * values1 = iq3k_table + ((extra_v << 6 ) & 0x40 );
1137- const uint16_t * values2 = iq3k_table + ((extra_v << 5 ) & 0x40 );
1138- const uint16_t * values3 = iq3k_table + ((extra_v << 4 ) & 0x40 );
1139- const uint16_t * values4 = iq3k_table + ((extra_v << 3 ) & 0x40 );
1117+ uint32_t extra32_1 = ((extra_v << 3 ) & 0x08080808 ) | ((extra_v << 5 ) & 0x80808080 );
1118+ uint32_t extra32_2 = ((extra_v << 2 ) & 0x08080808 ) | ((extra_v << 4 ) & 0x80808080 );
11401119
11411120 const int * q8;
11421121 int sumi[4 ] = {0 , 0 , 0 , 0 };
1143- int v;
11441122 for (int i = 0 ; i < 2 ; ++i) {
11451123 uint32_t vl = ql[2 *i+0 ] | (ql[2 *i+1 ] << 16 );
1146- uint32_t vh = ((qh[2 *i+0 ] | (qh[2 *i+1 ] << 16 )) >> 4 *ib128) << 2 ;
1124+ uint32_t vh = ((qh[2 *i+0 ] | (qh[2 *i+1 ] << 16 )) >> 4 *ib128);
1125+
1126+ uint32_t val1 = ((vl >> 0 ) & 0x33333333 ) | extra32_1 | ((vh << 2 ) & 0x04040404 ) | ((vh << 4 ) & 0x40404040 );
1127+ uint32_t val2 = ((vl >> 2 ) & 0x33333333 ) | extra32_2 | ((vh << 1 ) & 0x04040404 ) | ((vh << 3 ) & 0x40404040 );
1128+ int2 v1 = get_int_from_table_16 (val1, iq3nl_values);
1129+ int2 v2 = get_int_from_table_16 (val2, iq3nl_values);
11471130
11481131 q8 = (const int *)bq8_1[4 *ib128+0 ].qs + 2 *il8;
1149- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1150- v = int_from_table_2 (aux8, values1);
1151- sumi[0 ] = ggml_cuda_dp4a (v, q8[i], sumi[0 ]);
1152- vl >>= 2 ; vh >>= 1 ;
1132+ sumi[0 ] = ggml_cuda_dp4a (v1.x , q8[i], sumi[0 ]);
11531133
11541134 q8 += sizeof (block_q8_1)/4 ;
1155- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1156- v = int_from_table_2 (aux8, values2);
1157- sumi[1 ] = ggml_cuda_dp4a (v, q8[i], sumi[1 ]);
1158- vl >>= 2 ; vh >>= 1 ;
1135+ sumi[1 ] = ggml_cuda_dp4a (v2.x , q8[i], sumi[1 ]);
11591136
11601137 q8 += sizeof (block_q8_1)/4 ;
1161- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1162- v = int_from_table_2 (aux8, values3);
1163- sumi[2 ] = ggml_cuda_dp4a (v, q8[i], sumi[2 ]);
1164- vl >>= 2 ; vh >>= 1 ;
1138+ sumi[2 ] = ggml_cuda_dp4a (v1.y , q8[i], sumi[2 ]);
11651139
11661140 q8 += sizeof (block_q8_1)/4 ;
1167- aux32 = (vl & 0x03030303 ) | (vh & 0x04040404 );
1168- v = int_from_table_2 (aux8, values4);
1169- sumi[3 ] = ggml_cuda_dp4a (v, q8[i], sumi[3 ]);
1170-
1141+ sumi[3 ] = ggml_cuda_dp4a (v2.y , q8[i], sumi[3 ]);
11711142 }
11721143 const uint16_t * sl16 = (const uint16_t *)bq3->scales ;
1173- aux32 = __vsub4 (((sl16[0 ] | (sl16[1 ] << 16 )) >> 4 *ib128) & 0x0f0f0f0f , 0x10101010 );
1144+ int32_t aux32 = __vsub4 (((sl16[0 ] | (sl16[1 ] << 16 )) >> 4 *ib128) & 0x0f0f0f0f , 0x10101010 );
11741145 const int8_t * a8 = (const int8_t *)&aux32;
11751146 *result += d * (__low2float (bq8_1[4 *ib128+0 ].ds ) * (a8[0 ] + ((extra << 4 ) & 0x10 )) * sumi[0 ] +
11761147 __low2float (bq8_1[4 *ib128+1 ].ds ) * (a8[1 ] + ((extra << 3 ) & 0x10 )) * sumi[1 ] +
0 commit comments