@@ -1956,25 +1956,14 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-    constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_Q6_K);
+#if defined(NEW_MMA_AVAILABLE)

-#if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
-#if defined(AMD_MMA_AVAILABLE)
-    typedef tile<32, 4, int> tile_A;
-    typedef tile<32, 4, int> tile_B;
-    typedef tile<32, 32, int> tile_C;
-#else
     typedef tile<16, 4, int> tile_A;
     typedef tile< 8, 4, int> tile_B;
     typedef tile<16, 8, int> tile_C;
-#endif

     constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q6_K, mmq_x);
-#if defined(AMD_MMA_AVAILABLE)
-    constexpr int rows_per_warp = granularity; // 32
-#else
     constexpr int rows_per_warp = 2 * granularity;
-#endif
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

     y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
@@ -1988,13 +1977,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

     tile_A A[ntx][8];
-#if defined(AMD_MMA_AVAILABLE)
-    int scA[ntx][tile_C::ne][8];
-    float dA[ntx][tile_C::ne];
-#else
     int scA[ntx][tile_C::ne/2][8];
     float dA[ntx][tile_C::ne/2];
-#endif

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -2011,13 +1995,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int k0 = k00 + k01;

 #pragma unroll
-#if defined(AMD_MMA_AVAILABLE)
-            for (int l = 0; l < tile_C::ne; ++l) {
-                const int i = i0 + n*tile_C::I + tile_C::get_i(l);
-#else
             for (int l = 0; l < tile_C::ne/2; ++l) {
                 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
-#endif
+
                 const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                 const int8_t * sc = (const int8_t *) &sc_packed;
 
@@ -2029,13 +2009,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }

 #pragma unroll
-#if defined(AMD_MMA_AVAILABLE)
-        for (int l = 0; l < tile_C::ne; ++l) {
-            const int i = i0 + n*tile_C::I + tile_C::get_i(l);
-#else
         for (int l = 0; l < tile_C::ne/2; ++l) {
             const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
-#endif
+
             dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }
@@ -2047,29 +2023,18 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
         for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             tile_B B[2];
+            float dB[tile_C::ne/2];

-#if defined(AMD_MMA_AVAILABLE)
-            load_ldmatrix(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
-            load_ldmatrix(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
-#else
             // Here load_generic is faster than load_ldmatrix.
             load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
             load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
-#endif

-#if defined(AMD_MMA_AVAILABLE)
-            float dB;
-            const int j = j0 + tile_C::get_j(0);
-            dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
-#else
-            float dB[tile_C::ne/2];
 #pragma unroll
             for (int l = 0; l < tile_C::ne/2; ++l) {
                 const int j = j0 + tile_C::get_j(l);

                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
-#endif

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
@@ -2079,11 +2044,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
-#if defined(AMD_MMA_AVAILABLE)
-                    tmp[n][l] += (C[0].x[l]*scA[n][l][k01/4 + 0] + C[1].x[l]*scA[n][l][k01/4 + 1])*dB;
-#else
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
-#endif
                 }
             }
         }
@@ -2092,11 +2053,55 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
             for (int l = 0; l < tile_C::ne; ++l) {
-#if defined(AMD_MMA_AVAILABLE)
-                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l];
-#else
                 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
-#endif
+            }
+        }
+    }
+#elif defined(AMD_MMA_AVAILABLE)
+    typedef tile<32, 4, int> tile_A;
+    typedef tile<32, 4, int> tile_B;
+    typedef tile<32, 32, int> tile_C;
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = threadIdx.y * tile_A::I;
+
+    int scA[tile_C::ne][2];
+    float dA[tile_C::ne];
+
+#pragma unroll
+    for (int l = 0; l < tile_C::ne; ++l) {
+        const int i = i0 + tile_C::get_i(l);
+        scA[l][0] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 0];
+        scA[l][1] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 1];
+        dA[l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
+    }
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A;
+        load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
+            tile_B B;
+            load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            float dB;
+            const int j = j0 + tile_C::get_j(0);
+            dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+            tile_C C;
+            mma(C, A, B);
+
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int8_t * sc = (const int8_t *) scA[l];
+                sum[(j0/tile_C::J)*tile_C::ne + l] += C.x[l] * sc[k01/4] * dA[l] * dB;
             }
         }
     }
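
Note (not part of the patch): both branches above decode the q6_K sub-block scales by packing four signed 8-bit scales into a single int and reinterpreting its bytes, via the (const int8_t *) &sc_packed and (const int8_t *) scA[l] casts. A minimal host-side sketch of just that access pattern, with made-up scale values and no relation to the real tile layout or loading logic:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        const int8_t scales[4] = {-3, 7, -12, 5};      // hypothetical sub-block scales
        int packed;
        std::memcpy(&packed, scales, sizeof(packed));  // four int8 scales stored in one int, as in x_sc

        const int8_t * sc = (const int8_t *) &packed;  // same reinterpretation as in the kernel
        for (int k01 = 0; k01 < 16; k01 += 4) {
            std::printf("k01 = %2d -> scale %d\n", k01, sc[k01/4]);  // one byte per 4-wide k01 step
        }
        return 0;
    }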
@@ -2883,7 +2888,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 
 template <ggml_type type, int mmq_x, int warp_size, bool need_check>
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(AMD_MMA_AVAILABLE)
+    __launch_bounds__(warp_size*get_mmq_nwarps_device(type), 1)
+#elif defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
     __launch_bounds__(warp_size*get_mmq_nwarps_device(type), 2)
 #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
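
Note (not part of the patch): the last hunk changes the occupancy hint passed to __launch_bounds__; when AMD_MMA_AVAILABLE is defined, the kernel now requests a minimum of 1 resident block per multiprocessor instead of 2, which leaves the compiler more register headroom per thread (my reading of the change, not stated in the patch). A minimal CUDA/HIP sketch of the two-argument form of the attribute; the kernel and its bounds are illustrative only, not taken from ggml:

    // At most 256 threads per block, and register use constrained so that
    // at least 2 blocks can be resident per SM/CU.
    __global__ void __launch_bounds__(256, 2) scale_kernel(float * x, float s, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            x[i] *= s;
        }
    }

    // With a second argument of 1 (as in the AMD_MMA_AVAILABLE branch above),
    // only one block has to fit per multiprocessor, so the compiler may assign
    // more registers to each thread.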