@@ -140,9 +140,6 @@ static constexpr __device__ int get_mmq_y_device() {
140140#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
141141}
142142
143- // tile_x_sizes{qs, dm, sc}
144-
145- // TODO: TQ2_0 to minimize shared mem
146143#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0 }
147144#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0 }
148145#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2 /QI8_0 + mmq_y/(QI8_0/2 ), 0 }
@@ -1814,7 +1811,6 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
18141811#endif // INT8_MMA_AVAILABLE
18151812}
18161813
1817- // This is the first "simple" type with a block size of 256
18181814template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0 (
18191815 const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
18201816
@@ -1840,22 +1836,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
18401836 const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
18411837 const int qs0 = get_int_b2 (bxi->qs , kqsx);
18421838
1843- #ifdef INT8_MMA_AVAILABLE
1844-
18451839#pragma unroll
18461840 for (int l = 0 ; l < QR2_0; ++l) {
18471841 // 0..7, 32..39
18481842 // 8..15, 40..47
18491843 // 16..23, 48..55
18501844 // 24..31, 56..63
1851- // FIXME: this might assume WARP_SIZE is >= 32
18521845 const int k = (kqsx/8 )*32 + l*8 + kqsx % 8 ;
1846+ const int q = __vsub4 ((qs0 >> (2 *l)) & 0x03030303 , 0x01010101 );
18531847
1854- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4 ((qs0 >> ( 2 *l)) & 0x03030303 , 0x01010101 );
1855- }
1848+ # ifdef INT8_MMA_AVAILABLE
1849+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
18561850#else
1857- x_qs[i*(2 *WARP_SIZE + 1 ) + kqsx] = qs0;
1851+ // NOTE: this might assume WARP_SIZE is >= 32
1852+ x_qs[i*(2 *WARP_SIZE + 1 ) + k] = q;
18581853#endif // INT8_MMA_AVAILABLE
1854+ }
18591855 }
18601856
18611857 // TODO: does this work with WARP_SIZE != 32?
@@ -1872,45 +1868,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
18721868 const int k = threadIdx .x % (QI2_0/2 );
18731869
18741870#ifdef INT8_MMA_AVAILABLE
1875-
18761871 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d ;
18771872#else
18781873 x_df[i*(WARP_SIZE/4 ) + i/4 + k] = bxi->d ;
18791874#endif // INT8_MMA_AVAILABLE
18801875 }
18811876}
18821877
1883- template <int mmq_x, int mmq_y, int nwarps>
1884- static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a (
1885- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1886-
1887- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_TQ2_0, mmq_y);
1888- const int * x_qs = (const int *) x;
1889- const float * x_df = (const float *) x_qs + txs.qs ;
1890- const int * y_qs = (const int *) y + 4 ;
1891- const float * y_df = (const float *) y;
1892-
1893- #pragma unroll
1894- for (int k01 = 0 ; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) {
1895- const int k0 = k00 + k01;
1896-
1897- #pragma unroll
1898- for (int j0 = 0 ; j0 < mmq_x; j0 += nwarps) {
1899- const int j = j0 + threadIdx .y ;
1900-
1901- #pragma unroll
1902- for (int i0 = 0 ; i0 < mmq_y; i0 += WARP_SIZE) {
1903- const int i = i0 + threadIdx .x ;
1904-
1905- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMQ>(
1906- &x_qs[i*(2 *WARP_SIZE + 1 ) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01],
1907- x_df[i*(2 *WARP_SIZE/QI8_0) + i/(QI8_0/2 )], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1908- // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1909- }
1910- }
1911- }
1912- }
1913-
19141878template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl (
19151879 const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
19161880
@@ -2535,7 +2499,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
25352499 static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ;
25362500 static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
25372501 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2538- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a <mmq_x, mmq_y, nwarps>;
2502+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a <mmq_x, mmq_y, nwarps>;
25392503};
25402504
25412505template <int mmq_x, int mmq_y, int nwarps, bool need_check>
0 commit comments