@@ -891,23 +891,56 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-    constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_Q8_0);
-    // Tile definitions
-    typedef tile<16, 8, int> tile_A;
 #if defined(AMD_MMA_AVAILABLE)
+    typedef tile<16, 8, int> tile_A;
     typedef tile<16, 8, int> tile_B;
     typedef tile<16, 16, int> tile_C;
+
+    const int * x_qs = (const int *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_dm = (const half2 *) y;
+
+    const int i0 = threadIdx.y * tile_A::I;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+        const int k0 = k00 + k01;
+
+        tile_A A;
+        load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+
+        float2 dmA[tile_C::ne];
+#pragma unroll
+        for (int l = 0; l < tile_C::ne; ++l) {
+            const int i = i0 + tile_C::get_i(l);
+            dmA[l] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
+            tile_B B;
+            load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            float2 dsB;
+            const int j = j0 + tile_C::get_j(0);
+            dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+
+            tile_C C;
+            mma(C, A, B);
+
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].x*dsB.x*C.x[l];
+                sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].y*dsB.y;
+            }
+        }
+    }
 #else
+    typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
     typedef tile<16, 8, int> tile_C;
-#endif
 
     constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q8_0, mmq_x);
-#if defined(AMD_MMA_AVAILABLE)
-    constexpr int rows_per_warp = granularity; // 16
-#else
     constexpr int rows_per_warp = 2 * granularity;
-#endif
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
     y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
@@ -918,11 +951,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const half2 * y_dm = (const half2 *) y;
 
     tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
-#if defined(AMD_MMA_AVAILABLE)
-    float2 dmA[ntx][tile_C::ne][MMQ_TILE_NE_K/QI8_1];
-#else
     float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
-#endif
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -936,13 +965,8 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         }
 
 #pragma unroll
-#if defined(AMD_MMA_AVAILABLE)
-        for (int l = 0; l < tile_C::ne; ++l) {
-            const int i = i0 + n*tile_A::I + tile_C::get_i(l);
-#else
         for (int l = 0; l < tile_C::ne/2; ++l) {
             const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
-#endif
 
 #pragma unroll
             for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
@@ -958,25 +982,16 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #pragma unroll
         for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             tile_B B;
-#if defined(AMD_MMA_AVAILABLE)
-            load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
-#else
-            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
-#endif
-
-#if defined(AMD_MMA_AVAILABLE)
-            float2 dsB;
-            const int j = j0 + tile_C::get_j(0);
-            dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
-#else
             float2 dsB[tile_C::ne/2];
+
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+
 #pragma unroll
             for (int l = 0; l < tile_C::ne/2; ++l) {
                 const int j = j0 + tile_C::get_j(l);
 
                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
-#endif
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
@@ -985,17 +1000,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
-#if defined(AMD_MMA_AVAILABLE)
-                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l][k01/QI8_1].x*dsB.x*C.x[l];
-                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l][k01/QI8_1].y*dsB.y;
-#else
                     sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
                     sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
-#endif
                 }
             }
         }
     }
+#endif // AMD_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y>
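
Both branches above accumulate the same per-block quantity; only the tile shapes and the dsB replication differ. Below is a minimal scalar sketch of that contribution, assuming x_dm stores a (d, m) pair per block and y_dm stores (d, s) with s = d * sum(q); the function and parameter names are illustrative only, not part of the kernel.

#include <cstdint>

// Scalar reference for what dmA.x*dsB.x*C.x[l] + dmA.y*dsB.y computes per block.
// qx/qy are the 32 int8 quants of one x block and one y block; the mma()/dp4a
// instructions in the kernel produce the integer dot product (sumi) in hardware.
static float q8_1_q8_1_block_dot_ref(const int8_t * qx, const int8_t * qy,
                                     float dx, float mx,   // x block: scale d and offset/min term m
                                     float dy, float sy) { // y block: scale d and s = d * sum(q)
    int sumi = 0;
    for (int i = 0; i < 32; ++i) { // 32 values per block, i.e. QI8_1 = 8 packed int32
        sumi += int(qx[i]) * int(qy[i]);
    }
    return dx*dy*sumi + mx*sy; // scale product times integer dot, plus the offset correction
}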