@@ -543,11 +543,14 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
543543 const uint8_t * a8 = (const uint8_t *)&aux32;
544544 int v1, v2;
545545
546- int8_t s8[4 ];
547- s8[0 ] = ((bq2->scales [2 *(i4/4 )+0 ] & 0xf ) | ((extra >> 4 ) & 0x10 )) - 16 ;
548- s8[1 ] = ((bq2->scales [2 *(i4/4 )+0 ] >> 4 ) | ((extra >> 5 ) & 0x10 )) - 16 ;
549- s8[2 ] = ((bq2->scales [2 *(i4/4 )+1 ] & 0xf ) | ((extra >> 6 ) & 0x10 )) - 16 ;
550- s8[3 ] = ((bq2->scales [2 *(i4/4 )+1 ] >> 4 ) | ((extra >> 7 ) & 0x10 )) - 16 ;
546+ int32_t scales32;
547+ const uint16_t * scales16 = (const uint16_t *)bq2->scales ;
548+ scales32 = __vsub4 ((scales16[i4/4 ] | (scales16[i4/4 ] << 12 )) & 0x0f0f0f0f , 0x10101010 );
549+ int8_t * s8 = (int8_t *)&scales32;
550+ s8[0 ] += ((extra >> 4 ) & 0x10 );
551+ s8[1 ] += ((extra >> 6 ) & 0x10 );
552+ s8[2 ] += ((extra >> 5 ) & 0x10 );
553+ s8[3 ] += ((extra >> 7 ) & 0x10 );
551554
552555 aux32[0 ] = ((val1 >> 0 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 0 ) & 0x03030303 ); values = all_values + ((extra & 0x01 ) << 8 );
553556 v1 = int_from_table_4 (a8 + 0 , values);
@@ -557,12 +560,12 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
557560 aux32[0 ] = ((val1 >> 2 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 2 ) & 0x03030303 ); values = all_values + ((extra & 0x02 ) << 7 );
558561 v1 = int_from_table_4 (a8 + 0 , values);
559562 v2 = int_from_table_4 (a8 + 4 , values);
560- int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[1 ];
563+ int sumi2 = ggml_cuda_dp4a (v2, q8_2[1 ], ggml_cuda_dp4a (v1, q8_2[0 ], 0 )) * s8[2 ];
561564
562565 aux32[0 ] = ((val1 >> 4 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 4 ) & 0x03030303 ); values = all_values + ((extra & 0x04 ) << 6 );
563566 v1 = int_from_table_4 (a8 + 0 , values);
564567 v2 = int_from_table_4 (a8 + 4 , values);
565- int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[2 ];
568+ int sumi3 = ggml_cuda_dp4a (v2, q8_3[1 ], ggml_cuda_dp4a (v1, q8_3[0 ], 0 )) * s8[1 ];
566569
567570 aux32[0 ] = ((val1 >> 6 ) & 0x03030303 ); aux32[1 ] = ((val2 >> 6 ) & 0x03030303 ); values = all_values + ((extra & 0x08 ) << 5 );
568571 v1 = int_from_table_4 (a8 + 0 , values);
0 commit comments