@@ -909,13 +909,6 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
909909 tile_A A;
910910 load_ldmatrix (A, x_qs + i0*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
911911
912- float2 dmA[tile_C::ne];
913- #pragma unroll
914- for (int l = 0 ; l < tile_C::ne; ++l) {
915- const int i = i0 + tile_C::get_i (l);
916- dmA[l] = __half22float2 (x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
917- }
918-
919912#pragma unroll
920913 for (int j0 = 0 ; j0 < mmq_x; j0 += tile_C::J) {
921914 tile_B B;
@@ -929,8 +922,9 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
929922 mma (C, A, B);
930923
931924 for (int l = 0 ; l < tile_C::ne; ++l) {
932- sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].x *dsB.x *C.x [l];
933- sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].y *dsB.y ;
925+ float2 dmA = __half22float2 (x_dm[(i0 + tile_C::get_i (l))*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
926+ sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.x *dsB.x *C.x [l];
927+ sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.y *dsB.y ;
934928 }
935929 }
936930 }
@@ -2081,24 +2075,12 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20812075
20822076 const int i0 = threadIdx .y * tile_A::I;
20832077
2084- int scA[tile_C::ne][2 ];
2085- float dA[tile_C::ne];
2086-
2087- #pragma unroll
2088- for (int l = 0 ; l < tile_C::ne; ++l) {
2089- const int i = i0 + tile_C::get_i (l);
2090- scA[l][0 ] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 0 ];
2091- scA[l][1 ] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 1 ];
2092- dA[l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
2093- }
2094-
20952078 for (int k01 = 0 ; k01 < MMQ_TILE_NE_K; k01 += 4 ) {
20962079 const int k0 = k00 + k01;
20972080
20982081 tile_A A;
20992082 load_ldmatrix (A, x_qs + i0*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
21002083
2101- #pragma unroll
21022084 for (int j0 = 0 ; j0 < mmq_x; j0 += tile_C::J) {
21032085 tile_B B;
21042086 load_ldmatrix (B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
@@ -2111,8 +2093,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
21112093 mma (C, A, B);
21122094
21132095 for (int l = 0 ; l < tile_C::ne; ++l) {
2114- const int8_t * sc = (const int8_t *) scA[l] ;
2115- sum[(j0/tile_C::J)*tile_C::ne + l] += C.x [l] * sc[k01/4 ] * dA[l ] * dB;
2096+ const int8_t * sc = (const int8_t *) (x_sc + (i0 + tile_C::get_i (l))*MMQ_MMA_TILE_X_K_Q6_K + k00/ 16 ) ;
2097+ sum[(j0/tile_C::J)*tile_C::ne + l] += C.x [l] * sc[k01/4 ] * x_df[(i0 + tile_C::get_i (l))*MMQ_MMA_TILE_X_K_Q6_K ] * dB;
21162098 }
21172099 }
21182100 }
@@ -2858,9 +2840,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28582840 const int * by0 = y + ncols_y*(kb0*(qk*sizeof (block_q8_1_mmq) / (4 *QK8_1*sizeof (int ))) + 0 *sizeof (block_q8_1_mmq)/sizeof (int ));
28592841#pragma unroll
28602842 for (int l0 = 0 ; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
2861- int l = l0 + threadIdx .y *warp_size + threadIdx .x ;
2843+ int l = ( l0 + threadIdx .y *warp_size + threadIdx .x ) % (mmq_x*MMQ_TILE_Y_K) ;
28622844
2863- if (l < mmq_x*MMQ_TILE_Y_K) tile_y[l] = by0[l];
2845+ tile_y[l] = by0[l];
28642846 }
28652847 }
28662848
@@ -2874,9 +2856,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28742856 const int * by0 = y + ncols_y*(kb0*(qk*sizeof (block_q8_1_mmq) / (4 *QK8_1*sizeof (int ))) + 1 *sizeof (block_q8_1_mmq)/sizeof (int ));
28752857#pragma unroll
28762858 for (int l0 = 0 ; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
2877- int l = l0 + threadIdx .y *warp_size + threadIdx .x ;
2859+ int l = ( l0 + threadIdx .y *warp_size + threadIdx .x ) % (mmq_x*MMQ_TILE_Y_K) ;
28782860
2879- if (l < mmq_x*MMQ_TILE_Y_K) tile_y[l] = by0[l];
2861+ tile_y[l] = by0[l];
28802862 }
28812863 }
28822864
@@ -2899,9 +2881,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28992881
29002882template <ggml_type type, int mmq_x, int warp_size, bool need_check>
29012883#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2902- #if defined(AMD_MMA_AVAILABLE)
2903- __launch_bounds__ (warp_size*get_mmq_nwarps_device (type), 1)
2904- #elif defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
2884+ #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
29052885 __launch_bounds__ (warp_size*get_mmq_nwarps_device (type), 2)
29062886#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
29072887#else
0 commit comments