@@ -123,7 +123,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #endif // defined(NEW_MMA_AVAILABLE)
-#endif // defined(AMD_MMA_AVAILABLE)
+#endif // defined(AMD_MMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -231,21 +231,21 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
     if (amd_mma_available(cc)) {
         switch (type) {
             // vec_dot_q8_0_q8_1_mma
-            case GGML_TYPE_Q4_0:
-            case GGML_TYPE_Q5_0:
-            case GGML_TYPE_Q8_0:
-            case GGML_TYPE_IQ2_XXS:
-            case GGML_TYPE_IQ3_XXS:
-            case GGML_TYPE_IQ3_S:
-            case GGML_TYPE_IQ4_XS:
-            case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ3_S:
+            case GGML_TYPE_IQ4_XS:
+            case GGML_TYPE_IQ4_NL:
                 return mmq_x >= 128 ? 32 : 16;
             // vec_dot_q8_1_q8_1_mma
-            case GGML_TYPE_Q4_1:
-            case GGML_TYPE_Q5_1:
-            case GGML_TYPE_Q4_K:
-            case GGML_TYPE_Q5_K:
-            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_IQ1_S:
                 return mmq_x >= 128 ? 32 : 16;
             case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
             case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
@@ -265,21 +265,21 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
 static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const int mmq_x) {
     switch (type) {
         // vec_dot_q8_0_q8_1_mma
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             return mmq_x >= 128 ? 32 : 16;
         // vec_dot_q8_1_q8_1_mma
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ1_S:
             return mmq_x >= 128 ? 32 : 16;
         case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
         case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
@@ -305,21 +305,21 @@ static int get_mmq_nwarps_host(ggml_type type, const int cc) {
     if (amd_mma_available(cc)) {
         switch (type) {
             // vec_dot_q8_0_q8_1_mma
-            case GGML_TYPE_Q4_0:
-            case GGML_TYPE_Q5_0:
-            case GGML_TYPE_Q8_0:
-            case GGML_TYPE_IQ2_XXS:
-            case GGML_TYPE_IQ3_XXS:
-            case GGML_TYPE_IQ3_S:
-            case GGML_TYPE_IQ4_XS:
-            case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ3_S:
+            case GGML_TYPE_IQ4_XS:
+            case GGML_TYPE_IQ4_NL:
                 return 8;
             // vec_dot_q8_1_q8_1_mma
-            case GGML_TYPE_Q4_1:
-            case GGML_TYPE_Q5_1:
-            case GGML_TYPE_Q4_K:
-            case GGML_TYPE_Q5_K:
-            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_IQ1_S:
                 return 8;
             case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
             case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
@@ -339,21 +339,21 @@ static int get_mmq_nwarps_host(ggml_type type, const int cc) {
 static constexpr __device__ int get_mmq_nwarps_device(ggml_type type) {
     switch (type) {
         // vec_dot_q8_0_q8_1_mma
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             return 8;
         // vec_dot_q8_1_q8_1_mma
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ1_S:
             return 8;
         case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
         case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
@@ -851,7 +851,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 
     for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
         const int k0 = k00 + k01;
-
+
         tile_A A[ntx];
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
@@ -1019,7 +1019,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 
     for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
         const int k0 = k00 + k01;
-
+
         tile_A A[ntx];
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
@@ -1101,7 +1101,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         float2 dsB[tile_C::ne/2];
 
         load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
-
+
 #pragma unroll
         for (int l = 0; l < tile_C::ne/2; ++l) {
             const int j = j0 + tile_C::get_j(l);
@@ -1258,7 +1258,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 
     for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
         const int k0 = k00 + k01;
-
+
         tile_A A[ntx];
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
@@ -1272,7 +1272,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
             float dB;
             const int j = j0 + tile_C::get_j(0);
             dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
-
+
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 tile_C C;
@@ -1557,7 +1557,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 
     for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
         const int k0 = k00 + k01;
-
+
         tile_A A[ntx];
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
@@ -1571,8 +1571,8 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             float dB, sB;
             const int j = j0 + tile_C::get_j(0);
             dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
-            sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
-               : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+            sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+               : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
                               : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
 
             tile_C Cm;
@@ -2060,7 +2060,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 #endif // NEW_MMA_AVAILABLE
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
-    constexpr int nrows = warp_size / threads_per_row;
+    constexpr int nrows = warp_size / threads_per_row;
     const int txi = threadIdx.x % threads_per_row;
 
 #pragma unroll
@@ -2291,7 +2291,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 
     for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
         const int k0 = k00 + k01;
-
+
         tile_A A[ntx];
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
@@ -2358,8 +2358,8 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 = kbx * (2 * QI4_NL) + kqsx;
-
+        const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
@@ -2457,7 +2457,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_IQ2_XS);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
+
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
@@ -2584,7 +2584,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_IQ3_XXS);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
+
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
@@ -2644,7 +2644,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_IQ3_S);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
+
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
@@ -2711,7 +2711,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_IQ3_S);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
+
 #if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
     half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
@@ -2836,7 +2836,7 @@ template<int mmq_x, int mmq_y, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
     const int stride, const int i_max, const int j_max) {
-    constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_Q8_0); // Always 8
+    constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_Q8_0); // Always 8
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
 #pragma unroll