@@ -66,6 +66,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
6666 case GGML_TYPE_Q5_K:
6767 return MMQ_Q8_1_DS_LAYOUT_DS4;
6868 case GGML_TYPE_Q6_K:
69+ case GGML_TYPE_TQ2_0:
6970 case GGML_TYPE_IQ2_XXS:
7071 case GGML_TYPE_IQ2_XS:
7172 case GGML_TYPE_IQ2_S:
@@ -165,6 +166,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
165166 type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
166167 type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
167168 type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
169+ type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 :
168170 type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
169171 type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
170172 type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
@@ -200,6 +202,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
200202 type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201203 type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
202204 type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
205+ type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
203206 type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
204207 type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
205208 type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -1876,6 +1879,68 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
18761879#endif // INT8_MMA_AVAILABLE
18771880}
18781881
// Load one tile of TQ2_0-quantized weights into the MMQ x tile buffer.
// Each source block packs quants as 2-bit crumbs in bxi->qs with a single
// float scale bxi->d. Crumbs are widened to one signed value per byte:
// masking with 0x03030303 yields 0..3 per byte, and __vsub4(..., 0x01010101)
// subtracts 1 per byte, giving -1..2 (presumably ternary -1/0/+1 — the TQ2_0
// name suggests ternary; confirm against the block_tq2_0 definition).
// Two storage layouts are produced depending on INT8_MMA_AVAILABLE:
//   - MMA:  quants strided by MMQ_MMA_TILE_X_K_Q8_0, scales after 2*WARP_SIZE ints
//   - dp4a: quants strided by 2*WARP_SIZE + 1, scales at txs.qs with an extra
//     i/4 offset (NOTE(review): the +1 / i/4 terms look like bank-conflict
//     padding — intent inferred from the pattern, confirm)
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_tile + 2*WARP_SIZE);   // scales stored after the quant region
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);          // scales follow the qs region
#endif // INT8_MMA_AVAILABLE

    const int kqsx = threadIdx.x % QI2_0;              // this thread's position within one quant block

    // First pass: unpack and store the quantized values.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;   // tile row for this thread

        if (need_check) {
            i = min(i, i_max);   // clamp (rather than skip) so the warp stays convergent at the tail
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
        const int qs0 = get_int_b2(bxi->qs, kqsx);     // 4 bytes = 16 packed 2-bit crumbs

#pragma unroll
        for (int l = 0; l < QR2_0; ++l) {
            // k values produced per kqsx across the l iterations:
            // 0..7, 32..39
            // 8..15, 40..47
            // 16..23, 48..55
            // 24..31, 56..63
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
            // take the l-th crumb of each byte, then subtract 1 per byte (0..3 -> -1..2)
            const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);

#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
            x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
        }
    }

    // Second pass: store the per-block float scales (fewer values per row,
    // so each warp covers twice as many rows per iteration as above).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;

        const int k = threadIdx.x % (QI2_0/2);         // scale column within the row

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
        x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}
1943+
18791944template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl (
18801945 const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
18811946
@@ -2503,6 +2568,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
25032568 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
25042569};
25052570
// MMQ kernel configuration for TQ2_0. Tiles are filled by load_tiles_tq2_0;
// the vec-dot stages reuse the Q8_0 kernels (the tile loader emits plain
// signed per-byte values with one float scale, matching what those kernels
// consume — presumably why the Q8_0 layout constants are used; confirm
// against the Q8_0 tile layout if changing either side).
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
    static constexpr int              vdr          = VDR_TQ2_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
2578+
25062579template <int mmq_x, int mmq_y, int nwarps, bool need_check>
25072580struct mmq_type_traits <mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
25082581 static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
@@ -2993,6 +3066,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
29933066extern DECL_MMQ_CASE (GGML_TYPE_Q4_K);
29943067extern DECL_MMQ_CASE (GGML_TYPE_Q5_K);
29953068extern DECL_MMQ_CASE (GGML_TYPE_Q6_K);
3069+ extern DECL_MMQ_CASE (GGML_TYPE_TQ2_0);
29963070extern DECL_MMQ_CASE (GGML_TYPE_IQ2_XXS);
29973071extern DECL_MMQ_CASE (GGML_TYPE_IQ2_XS);
29983072extern DECL_MMQ_CASE (GGML_TYPE_IQ2_S);
0 commit comments