diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 181f179ed171c..c2727abd4c8b0 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -210,6 +210,7 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) +option(GGML_HIP_MMQ_WMMA "ggml: enable WMMA MMA for RDNA4 in MMQ" ON) option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF) option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ca876459d404d..edfdbf6555bdf 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -228,6 +228,9 @@ static const char * cu_get_error_str(CUresult err) { #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #define VOLTA_MMA_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(GGML_USE_HIP) && defined(RDNA4) +#define AMD_WMMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA) #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define TURING_MMA_AVAILABLE @@ -287,6 +290,10 @@ static bool volta_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; } +static bool amd_wmma_available(const int cc) { + return GGML_CUDA_CC_IS_RDNA4(cc); +} + static bool turing_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a7a28fd1ae660..9dd48c7535773 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -73,7 +73,7 @@ namespace ggml_cuda_mma { static constexpr int I = I_; static constexpr int J = J_; -#if defined(GGML_USE_HIP) +#if defined(AMD_MFMA_AVAILABLE) static constexpr int ne = I * J / 64; T x[ne] = {0}; @@ -149,6 +149,28 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_WMMA_AVAILABLE) //adjusted the mapping for RDNA 4 + + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 16) { + return 8 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } #else static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -353,6 +375,22 @@ namespace ggml_cuda_mma { const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); xi[0] = xs[0]; } + +#elif defined(AMD_WMMA_AVAILABLE) + if constexpr (I == 16 && J == 4) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + }else if constexpr (I == 16 && J == 8) { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + + const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); + xi[1] = xs1[0]; + }else{ + NO_DEVICE_CODE; + } #else #pragma unroll for (int l = 0; l < t.ne; ++l) { @@ -665,6 +703,36 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + +#elif defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + +#if defined(RDNA4) + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); +#endif // defined(RDNA4) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -691,6 +759,7 @@ namespace ggml_cuda_mma { acc[0], 0, 0, 0); #endif // defined(CDNA3) + #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -735,4 +804,31 @@ namespace ggml_cuda_mma { mma(D16[1], A16[1], B); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } + +static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif + } } + diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a2c8760abea93..ece1ae4391231 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -306,5 +306,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + if (amd_wmma_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return true; + } + + } + + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index c9a07e82fedf2..339470b44c449 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -92,7 +92,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : + return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -102,7 +102,7 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) @@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -231,7 +231,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - if (amd_mfma_available(cc)) { + if (amd_mfma_available(cc) || amd_wmma_available(cc)) { return mmq_x >= 128 ? 32 : 16; } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; @@ -240,7 +240,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { } } -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 32 : 16; } @@ -265,7 +265,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { #endif // (GGML_USE_HIP) static constexpr __device__ int mmq_get_nwarps_device() { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 8; #else return 256/ggml_cuda_get_physical_warp_size(); @@ -279,14 +279,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; @@ -305,7 +305,7 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else @@ -327,11 +327,11 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -382,14 +382,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; @@ -408,12 +408,12 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; @@ -430,11 +430,11 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -485,14 +485,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; @@ -527,13 +527,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; @@ -550,11 +550,11 @@ template static __device__ __forceinline__ void loa const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -563,14 +563,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / threads_per_row; @@ -603,13 +603,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; @@ -626,11 +626,11 @@ template static __device__ __forceinline__ void loa const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -639,14 +639,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp constexpr int threads_per_row = 32; @@ -665,13 +665,13 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; @@ -688,11 +688,11 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -701,14 +701,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); constexpr int nrows = warp_size / threads_per_row; @@ -730,13 +730,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); const int k0 = kbx * (2 * QI_MXFP4) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; @@ -753,11 +753,11 @@ template static __device__ __forceinline__ void loa const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #else x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -796,7 +796,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -927,7 +927,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -965,7 +965,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) typedef tile<16, 8, int> tile_A; typedef tile<16, 8, int> tile_B; typedef tile<16, 16, int> tile_C; @@ -1087,7 +1087,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } // Used for Q3_K, IQ2_S, and IQ2_XS @@ -1170,6 +1170,54 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1257,21 +1305,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; @@ -1295,11 +1343,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1310,11 +1358,11 @@ template static __device__ __forceinline__ void loa const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1438,6 +1486,72 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_C Cd; mma(Cd, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -1574,7 +1688,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( @@ -1582,7 +1696,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1618,11 +1732,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -1649,7 +1763,7 @@ template static __device__ __forceinline__ void loa const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; @@ -1659,10 +1773,10 @@ template static __device__ __forceinline__ void loa } #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; @@ -1675,7 +1789,7 @@ template static __device__ __forceinline__ void loa x_df[i] = bxi->d; } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) || defined(AMD_WMMA_AVAILABLE) } template @@ -1728,7 +1842,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else @@ -1736,7 +1850,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); constexpr int nrows = warp_size / threads_per_row; @@ -1753,19 +1867,19 @@ template static __device__ __forceinline__ void loa const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, txi); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) // Need if on AMD instead of % because warp_size == 64 // This causes double work and throughput loss (MI300X) // H100 loses about 100 t/s with 'if' condition over '%' @@ -1774,7 +1888,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1829,7 +1943,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -1872,7 +1986,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1908,16 +2022,16 @@ template static __device__ __forceinline__ void loa const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1930,7 +2044,7 @@ template static __device__ __forceinline__ void loa #else int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; { -#endif // defined(AMD_MFMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if (need_check) { i = min(i, i_max); } @@ -1986,7 +2100,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } template @@ -2029,7 +2143,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); @@ -2038,7 +2152,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); constexpr int nrows = warp_size / threads_per_row; @@ -2065,13 +2179,13 @@ template static __device__ __forceinline__ void loa const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } #pragma unroll @@ -2084,11 +2198,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; @@ -2102,11 +2216,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2190,6 +2304,56 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_C C; mma(C, A[n], B[0]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + typedef tile<16, 4, int> tile_A; + typedef tile<16, 4, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { const int i = i0 + n*tile_C::I + tile_C::get_i(l); @@ -2303,7 +2467,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; -#endif // AMD_MFMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( @@ -2311,14 +2475,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); constexpr int nrows = warp_size / threads_per_row; @@ -2340,13 +2504,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = kbx * (2 * QI4_NL) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; @@ -2363,11 +2527,11 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2376,14 +2540,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2414,22 +2578,22 @@ template static __device__ __forceinline__ void loa const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2438,14 +2602,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2472,24 +2636,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2498,15 +2662,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; const int kqsx = threadIdx.x % threads_per_row; @@ -2539,24 +2702,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2565,14 +2728,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2601,22 +2764,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2625,14 +2788,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2668,22 +2831,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2692,14 +2855,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); constexpr int nrows = warp_size / threads_per_row; @@ -2727,23 +2890,23 @@ template static __device__ __forceinline__ void loa const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2752,14 +2915,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; @@ -2779,13 +2942,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; @@ -2804,11 +2967,11 @@ template static __device__ __forceinline__ void loa const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2848,7 +3011,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int tileC_IJ = mmq_get_granularity_device(0); typedef tile tile_C; constexpr int rows_per_warp = granularity; @@ -2859,7 +3022,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); #else GGML_UNUSED(nwarps); @@ -3063,13 +3226,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 23b6889919f20..c22379bce492d 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -116,6 +116,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (NOT GGML_HIP_MMQ_WMMA) + add_compile_definitions(GGML_HIP_NO_MMQ_WMMA) +endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif()