@@ -683,37 +683,6 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
683683 }
684684}
685685
686- template <typename dst_t >
687- static __global__ void dequantize_block_iq3_ks_v1 (const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
688-
689- int64_t ii = blockIdx .x ;
690- int64_t row = (QK_K * ii) / n_per_row;
691- const char * cx = (const char *)vx + row * row_size;
692- float scale = *(const float *)cx;
693- const block_iq3_ks_v1 * x = (const block_iq3_ks_v1 *)(cx + sizeof (float ));
694- const int64_t i = ii - (row*n_per_row)/QK_K;
695-
696- const int tid = threadIdx .x ;
697- int ib128 = tid/16 ; // 0 or 1
698- int il = tid%16 ; // 0...15
699- dst_t * y = yy + ii*QK_K + 128 *ib128 + 2 *il;
700- // uint32_t sc = ((const uint32_t *)x[i].scales)[ib128];
701- // uint32_t aux32 =
702- const float dl1 = scale * ((x[i].scales [4 *ib128+0 ] & 254 ) - 127 );
703- const float dl2 = scale * ((x[i].scales [4 *ib128+1 ] & 254 ) - 127 );
704- const float dl3 = scale * ((x[i].scales [4 *ib128+2 ] & 254 ) - 127 );
705- const float dl4 = scale * ((x[i].scales [4 *ib128+3 ] & 254 ) - 127 );
706- const uint8_t * qs = x[i].qs + 32 *ib128 + 2 *il;
707- const uint8_t * qh = x[i].qh + 2 *il;
708- for (int j = 0 ; j < 2 ; ++j) {
709- const uint8_t h = qh[j] >> (4 *(ib128%2 ));
710- y[j+ 0 ] = dl1 * iq3nl_values[(((qs[j] >> 0 ) & 0x03 ) | ((h & 0x01 ) << 2 )) + ((x[i].scales [4 *ib128+0 ] & 1 ) << 3 )];
711- y[j+32 ] = dl2 * iq3nl_values[(((qs[j] >> 2 ) & 0x03 ) | ((h & 0x02 ) << 1 )) + ((x[i].scales [4 *ib128+1 ] & 1 ) << 3 )];
712- y[j+64 ] = dl3 * iq3nl_values[(((qs[j] >> 4 ) & 0x03 ) | ((h & 0x04 ) >> 0 )) + ((x[i].scales [4 *ib128+2 ] & 1 ) << 3 )];
713- y[j+96 ] = dl4 * iq3nl_values[(((qs[j] >> 6 ) & 0x03 ) | ((h & 0x08 ) >> 1 )) + ((x[i].scales [4 *ib128+3 ] & 1 ) << 3 )];
714- }
715- }
716-
717686template <typename dst_t >
718687static __global__ void dequantize_block_iq4_ks (const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
719688
@@ -1056,6 +1025,82 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_
10561025 }
10571026}
10581027
1028+ template <typename dst_t >
1029+ static __global__ void dequantize_block_iq3_ks_v1 (const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
1030+
1031+ int64_t ii = blockIdx .x ;
1032+ int64_t row = (QK_K * ii) / n_per_row;
1033+ const char * cx = (const char *)vx + row * row_size;
1034+ float scale = *(const float *)cx;
1035+ const block_iq3_ks_v1 * x = (const block_iq3_ks_v1 *)(cx + sizeof (float ));
1036+ const int64_t i = ii - (row*n_per_row)/QK_K;
1037+
1038+ const int tid = threadIdx .x ;
1039+ int ib128 = tid/16 ; // 0 or 1
1040+ int il = tid%16 ; // 0...15
1041+ dst_t * y = yy + ii*QK_K + 128 *ib128 + 2 *il;
1042+ // uint32_t sc = ((const uint32_t *)x[i].scales)[ib128];
1043+ // uint32_t aux32 =
1044+ const float dl1 = scale * ((x[i].scales [4 *ib128+0 ] & 254 ) - 127 );
1045+ const float dl2 = scale * ((x[i].scales [4 *ib128+1 ] & 254 ) - 127 );
1046+ const float dl3 = scale * ((x[i].scales [4 *ib128+2 ] & 254 ) - 127 );
1047+ const float dl4 = scale * ((x[i].scales [4 *ib128+3 ] & 254 ) - 127 );
1048+ const uint8_t * qs = x[i].qs + 32 *ib128 + 2 *il;
1049+ const uint8_t * qh = x[i].qh + 2 *il;
1050+ for (int j = 0 ; j < 2 ; ++j) {
1051+ const uint8_t h = qh[j] >> (4 *(ib128%2 ));
1052+ y[j+ 0 ] = dl1 * iq3nl_values[(((qs[j] >> 0 ) & 0x03 ) | ((h & 0x01 ) << 2 )) + ((x[i].scales [4 *ib128+0 ] & 1 ) << 3 )];
1053+ y[j+32 ] = dl2 * iq3nl_values[(((qs[j] >> 2 ) & 0x03 ) | ((h & 0x02 ) << 1 )) + ((x[i].scales [4 *ib128+1 ] & 1 ) << 3 )];
1054+ y[j+64 ] = dl3 * iq3nl_values[(((qs[j] >> 4 ) & 0x03 ) | ((h & 0x04 ) >> 0 )) + ((x[i].scales [4 *ib128+2 ] & 1 ) << 3 )];
1055+ y[j+96 ] = dl4 * iq3nl_values[(((qs[j] >> 6 ) & 0x03 ) | ((h & 0x08 ) >> 1 )) + ((x[i].scales [4 *ib128+3 ] & 1 ) << 3 )];
1056+ }
1057+ }
1058+
1059+ template <typename dst_t >
1060+ static __global__ void dequantize_block_iq3_ks (const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
1061+
1062+ int64_t ii = blockIdx .x ;
1063+ int64_t row = (QK_K * ii) / n_per_row;
1064+ const char * cx = (const char *)vx + row * row_size;
1065+ float scale = *(const ggml_half *)cx;
1066+ const block_iq3_ks * x = (const block_iq3_ks *)(cx + sizeof (ggml_half));
1067+ const int64_t i = ii - (row*n_per_row)/QK_K;
1068+
1069+ const int64_t tid = threadIdx .x ;
1070+ const int64_t is = tid/16 ;
1071+ const int64_t il = tid%16 ;
1072+ dst_t * y = yy + ii*QK_K + 128 *is + 2 *il;
1073+ const uint8_t * qs = x[i].qs + 32 *is + 2 *il;
1074+ const uint8_t * qh = x[i].qh + 2 *il;
1075+ uint16_t extra = x[i].extra >> 4 *is;
1076+ const float d0 = scale * (int (((x[i].scales [0 ] >> 4 *is) & 0xf ) | ((extra << 4 ) & 0x10 )) - 16 );
1077+ const float d1 = scale * (int (((x[i].scales [1 ] >> 4 *is) & 0xf ) | ((extra << 3 ) & 0x10 )) - 16 );
1078+ const float d2 = scale * (int (((x[i].scales [2 ] >> 4 *is) & 0xf ) | ((extra << 2 ) & 0x10 )) - 16 );
1079+ const float d3 = scale * (int (((x[i].scales [3 ] >> 4 *is) & 0xf ) | ((extra << 1 ) & 0x10 )) - 16 );
1080+ extra >>= 8 ;
1081+ const int8_t * values0 = iq3nl_values + ((extra & 1 ) << 3 );
1082+ const int8_t * values1 = iq3nl_values + ((extra & 2 ) << 2 );
1083+ const int8_t * values2 = iq3nl_values + ((extra & 4 ) << 1 );
1084+ const int8_t * values3 = iq3nl_values + ((extra & 8 ) << 0 );
1085+ if constexpr (std::is_same_v<dst_t , nv_bfloat16>) {
1086+ for (int j = 0 ; j < 2 ; ++j) {
1087+ uint8_t h = qh[j] >> 4 *is;
1088+ y[j+ 0 ] = __float2bfloat16 (d0 * values0[((qs[j] >> 0 ) & 3 ) | ((h << 2 ) & 4 )]);
1089+ y[j+32 ] = __float2bfloat16 (d1 * values1[((qs[j] >> 2 ) & 3 ) | ((h << 1 ) & 4 )]);
1090+ y[j+64 ] = __float2bfloat16 (d2 * values2[((qs[j] >> 4 ) & 3 ) | ((h >> 0 ) & 4 )]);
1091+ y[j+96 ] = __float2bfloat16 (d3 * values3[((qs[j] >> 6 ) & 3 ) | ((h >> 1 ) & 4 )]);
1092+ }
1093+ } else {
1094+ for (int j = 0 ; j < 2 ; ++j) {
1095+ uint8_t h = qh[j] >> 4 *is;
1096+ y[j+ 0 ] = d0 * values0[((qs[j] >> 0 ) & 3 ) | ((h << 2 ) & 4 )];
1097+ y[j+32 ] = d1 * values1[((qs[j] >> 2 ) & 3 ) | ((h << 1 ) & 4 )];
1098+ y[j+64 ] = d2 * values2[((qs[j] >> 4 ) & 3 ) | ((h >> 0 ) & 4 )];
1099+ y[j+96 ] = d3 * values3[((qs[j] >> 6 ) & 3 ) | ((h >> 1 ) & 4 )];
1100+ }
1101+ }
1102+ }
1103+
10591104template <typename dst_t >
10601105static __global__ void dequantize_block_iq1_s_r4 (const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
10611106
@@ -1615,6 +1660,14 @@ static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t
16151660 dequantize_block_iq2_ks<<<nb, 32 , 0 , stream>>> (vx, y, n_per_row, row_size);
16161661}
16171662
1663+ template <typename dst_t >
1664+ static void dequantize_row_iq2_k_cuda (const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
1665+ const int64_t k = nrows * n_per_row;
1666+ // const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_K, n_per_row);
1667+ const int nb = (k + QK_K - 1 ) / QK_K;
1668+ dequantize_block_iq2_k<<<nb, 32 , 0 , stream>>> (vx, y);
1669+ }
1670+
16181671template <typename dst_t >
16191672static void dequantize_row_iq3_ks_v1_cuda (const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
16201673 const int64_t k = nrows * n_per_row;
@@ -1624,11 +1677,11 @@ static void dequantize_row_iq3_ks_v1_cuda(const void * vx, dst_t * y, const int6
16241677}
16251678
16261679template <typename dst_t >
1627- static void dequantize_row_iq2_k_cuda (const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
1680+ static void dequantize_row_iq3_ks_cuda (const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
16281681 const int64_t k = nrows * n_per_row;
1629- // const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_K , n_per_row);
1682+ const int64_t row_size = ggml_row_size (GGML_TYPE_IQ3_KS , n_per_row);
16301683 const int nb = (k + QK_K - 1 ) / QK_K;
1631- dequantize_block_iq2_k <<<nb, 32 , 0 , stream>>> (vx, y);
1684+ dequantize_block_iq3_ks <<<nb, 32 , 0 , stream>>> (vx, y, n_per_row, row_size );
16321685}
16331686
16341687template <typename dst_t >
@@ -1816,10 +1869,12 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
18161869 return dequantize_row_iq2_ks_cuda<nv_bfloat16>;
18171870 case GGML_TYPE_IQ2_K:
18181871 return dequantize_row_iq2_k_cuda<nv_bfloat16>;
1819- case GGML_TYPE_IQ3_K:
1820- return dequantize_row_iq3_k_cuda<nv_bfloat16>;
18211872 case GGML_TYPE_IQ3_KS_V1:
18221873 return dequantize_row_iq3_ks_v1_cuda<nv_bfloat16>;
1874+ case GGML_TYPE_IQ3_KS:
1875+ return dequantize_row_iq3_ks_cuda<nv_bfloat16>;
1876+ case GGML_TYPE_IQ3_K:
1877+ return dequantize_row_iq3_k_cuda<nv_bfloat16>;
18231878 case GGML_TYPE_IQ4_KSS:
18241879 return dequantize_row_iq4_kss_cuda<nv_bfloat16>;
18251880 case GGML_TYPE_IQ4_KS:
@@ -1918,6 +1973,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
19181973 return dequantize_row_iq3_ks_v1_cuda;
19191974 case GGML_TYPE_IQ2_K:
19201975 return dequantize_row_iq2_k_cuda;
1976+ case GGML_TYPE_IQ3_KS:
1977+ return dequantize_row_iq3_ks_cuda;
19211978 case GGML_TYPE_IQ3_K:
19221979 return dequantize_row_iq3_k_cuda;
19231980 case GGML_TYPE_IQ4_KSS:
@@ -2024,6 +2081,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
20242081 return dequantize_row_iq2_k_cuda;
20252082 case GGML_TYPE_IQ3_K:
20262083 return dequantize_row_iq3_k_cuda;
2084+ case GGML_TYPE_IQ3_KS:
2085+ return dequantize_row_iq3_ks_cuda;
20272086 case GGML_TYPE_IQ4_K:
20282087 return dequantize_row_iq4_k_cuda;
20292088 case GGML_TYPE_IQ5_K:
0 commit comments