From 5d1438624dbd7ac3081a4dcc05199da297464baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 29 Oct 2025 00:49:09 +0100 Subject: [PATCH 1/3] CUDA: Volta tensor core support for MMF --- ggml/src/ggml-cuda/common.cuh | 10 +- ggml/src/ggml-cuda/mma.cuh | 172 ++++++++++++++++++++++++++++++---- ggml/src/ggml-cuda/mmf.cu | 2 +- ggml/src/ggml-cuda/mmf.cuh | 98 +++++++++++++++++-- 4 files changed, 257 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1af23588301dd..54ad4413924a7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) { #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +// The Volta instructions are in principle available on Turing or newer but they are effectively unusable: +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#define VOLTA_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define TURING_MMA_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) { #endif //!defined(GGML_HIP_NO_MMQ_MFMA) } -// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static bool volta_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; +} + static bool turing_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c1f24243fe388..4dad11cf5d82b 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -18,6 +18,10 @@ #include "common.cuh" +// On Volta each warp is doing 4 8x8 mma operations in parallel. +// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile. +// However, the i indices in this file are by default permuted to simplify the index calculations. 
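// Editor's illustration (not part of the patch): the permuted vs. unpermuted i mapping for a
// 32x8 float tile on Volta can be checked in isolation with plain host code. The formulas are
// copied verbatim from the Volta get_i()/get_j() specializations further down; t stands in for
// threadIdx.x and l is the register index (ne == 32*8/32 == 8). The unpermuted layout is what
// the commented-out define below (GGML_CUDA_MMA_NO_VOLTA_PERM) selects.
#include <cstdio>

int main() {
    for (int t = 0; t < 32; ++t) {
        for (int l = 0; l < 8; ++l) {
            const int i_noperm = (((t % 16) / 4) * 8) | ((t / 16) * 4) | (l & 2) | (t % 2); // 4 stacked 8x8 tiles
            const int i_perm   = (l & 2) | (t & ~2);                                        // default, simpler index math
            const int j        = (t & 2) | (l & (4 + 1));
            printf("t=%2d l=%d -> i_noperm=%2d i_perm=%2d j=%d\n", t, l, i_noperm, i_perm, j);
        }
    }
    return 0;
}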
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM #if CUDART_VERSION >= 11080 @@ -86,6 +90,7 @@ namespace ggml_cuda_mma { return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } @@ -102,6 +107,32 @@ namespace ggml_cuda_mma { return threadIdx.x % 32; } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; + } + } +#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 32 && J == 8) { +#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM + return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2); +#else + return (l & 2) | (threadIdx.x & ~2); +#endif // GGML_CUDA_MMA_NO_VOLTA_PERM + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 32 && J == 8) { + return (threadIdx.x & 2) | (l & (4 + 1)); + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } #else @@ -111,12 +142,13 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 8 && (J == 4 || J == 8)) { return threadIdx.x / 4; - } else if constexpr (I == 16 && J == 8) { - return (l / 2) * 8 + threadIdx.x / 4; + } else if constexpr ((I == 16 || I == 32) && J == 8) { + return ((l / 2) * 8) | (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { - return ((l / 2) % 2) * 8 + threadIdx.x / 4; + return (((l / 2) % 2) * 8) | (threadIdx.x / 4); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } @@ -124,13 +156,14 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { - return 4 * l + threadIdx.x % 4; - } else if constexpr (I == 16 && J == 8) { - return 2 * (threadIdx.x % 4) + l % 2; + return (l * 4) | (threadIdx.x % 4); + } else if constexpr ((I == 16 || I == 32) && J == 8) { + return ((threadIdx.x % 4) * 2) | (l % 2); } else if constexpr (I == 16 && J == 16) { - return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2; + return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } #endif // defined(GGML_USE_HIP) @@ -140,6 +173,35 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; + +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + static constexpr int ne = I == 8 && J == 8 ? 
I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); + } else if constexpr (I == 32 && J == 8) { +#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM + return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); +#else + return threadIdx.x; +#endif // GGML_CUDA_MMA_NO_VOLTA_PERM + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr ((I == 8 || I == 32) && J == 8) { + return l; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; + } + } +#else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -147,25 +209,32 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return l * 8 + threadIdx.x / 4; + return (l * 8) | (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return (l % 2) * 8 + threadIdx.x / 4; + return ((l % 2) * 8) | (threadIdx.x / 4); + } else if constexpr (I == 32 && J == 8) { + return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return l * 4 + threadIdx.x % 4; + return (l * 4) | (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return (l / 2) * 4 + threadIdx.x % 4; + return ((l / 2) * 4) | (threadIdx.x % 4); + } else if constexpr (I == 32 && J == 8) { + return ((l & 2) * 2) | (threadIdx.x % 4); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA }; template @@ -179,23 +248,25 @@ namespace ggml_cuda_mma { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 4) { - return l * 8 + threadIdx.x / 4; + return (l * 8) | (threadIdx.x / 4); } else if constexpr (I == 16 && J == 8) { - return (l % 2) * 8 + threadIdx.x / 4; + return ((l % 2) * 8) | (threadIdx.x / 4); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 8 && J == 8) { - return l * 4 + threadIdx.x % 4; + return (l * 4) | (threadIdx.x % 4); } else if constexpr (I == 16 && J == 4) { return threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { - return (l / 2) * 4 + threadIdx.x % 4; + return ((l / 2) * 4) | (threadIdx.x % 4); } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); + return -1; } } }; @@ -263,8 +334,12 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(xs0, stride); - GGML_UNUSED(t); +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; +#else + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -277,11 +352,35 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); +#else +#if __CUDA_ARCH__ == 
GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #else load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } + template + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if 1 + // TODO: more generic handling + static_assert(sizeof(T) == 4, "bad type size"); + ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#else + load_generic(t, xs0, stride); +#endif // 1 +#else + tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t; + load_ldmatrix(t16[0], xs0 + 0*stride, stride); + load_ldmatrix(t16[1], xs0 + 16*stride, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + } + template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { @@ -546,4 +645,43 @@ namespace ggml_cuda_mma { NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } + + template + static __device__ __forceinline__ void mma( + tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile & B) { + tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D; + tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); + } + + static __device__ __forceinline__ void mma( + tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5])); + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); +#else + tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D; + tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + } } diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 9e2aaf52d6cce..2b0a61395b458 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -148,7 +148,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return turing_mma_available(cc); 
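// Editor's sketch (not part of the patch): what the updated GGML_TYPE_F16 case directly below
// amounts to. The cc constants follow the GGML convention (GGML_CUDA_CC_VOLTA == 700,
// GGML_CUDA_CC_TURING == 750); unlike the real helpers in common.cuh, this sketch compares the
// raw compute capability instead of ggml_cuda_highest_compiled_arch(cc), so treat it as
// illustrative only.
static bool mmf_f16_available_sketch(const int cc) {
    const bool volta  = cc == 700; // volta_mma_available(): exactly Volta
    const bool turing = cc >= 750; // turing_mma_available(): Turing or newer
    return volta || turing;        // mirrors: volta_mma_available(cc) || turing_mma_available(cc)
}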
+ return volta_mma_available(cc) || turing_mma_available(cc); case GGML_TYPE_BF16: return ampere_mma_available(cc); default: diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 49d5295be0ea0..f4bad0c8f2c63 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -20,17 +20,23 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( +static __device__ void mul_mat_f_impl( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + typedef tile<32, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<32, 8, float> tile_C; +#else + // In principle also possible to use tiles with I == 32, the performance difference is ~1%. typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -232,11 +238,43 @@ static __global__ void mul_mat_f( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + if constexpr (std::is_same_v) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + mul_mat_f_impl( + x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, + stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); +#else + NO_DEVICE_CODE; +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } else if constexpr (std::is_same_v || std::is_same_v) { +#ifdef AMPERE_MMA_AVAILABLE + mul_mat_f_impl( + x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, + stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); +#else + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } else { + static_assert(std::is_same_v, "bad type"); + } + GGML_UNUSED_VARS(x, y, ids, dst, 
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, + stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); +} //This kernel is for larger batch sizes of mul_mat_id template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f_ids( +static __device__ void mul_mat_f_ids_impl( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, @@ -245,9 +283,16 @@ static __global__ void mul_mat_f_ids( const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + typedef tile<32, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<32, 8, float> tile_C; +#else + // In principle also possible to use tiles with I == 32, the performance difference is ~1%. typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -476,6 +521,46 @@ static __global__ void mul_mat_f_ids( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { + if constexpr (std::is_same_v) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + mul_mat_f_ids_impl( + x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); +#else + NO_DEVICE_CODE; +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } else if constexpr (std::is_same_v || std::is_same_v) { +#ifdef AMPERE_MMA_AVAILABLE + mul_mat_f_ids_impl( + x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); +#else + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } else { + static_assert(std::is_same_v, "bad type"); + } + GGML_UNUSED_VARS( + x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, 
stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); +} + template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, @@ -533,7 +618,8 @@ void mul_mat_f_cuda( const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream, const mmf_ids_data * ids_data) { - typedef tile<16, 8, T> tile_A; + typedef tile<32, 8, T> tile_A_16; + typedef tile<32, 8, T> tile_A_32; typedef tile< 8, 8, T> tile_B; GGML_ASSERT(ncols_x % 2 == 0); @@ -559,7 +645,7 @@ void mul_mat_f_cuda( } constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_iter = nwarps_best * (volta_mma_available ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; From 7efb6acfe6cb552af17cb9555f90f378b1587c97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 30 Oct 2025 15:39:08 +0100 Subject: [PATCH 2/3] more generic checks for hardware support --- ggml/src/ggml-cuda/mma.cuh | 81 +++++++++++++++++----- ggml/src/ggml-cuda/mmf.cuh | 133 ++++++++++--------------------------- 2 files changed, 100 insertions(+), 114 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 4dad11cf5d82b..a7a28fd1ae660 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -77,6 +77,15 @@ namespace ggml_cuda_mma { static constexpr int ne = I * J / 64; T x[ne] = {0}; + static constexpr __device__ bool supported() { + if (I == 64 && J == 2) return true; + if (I == 16 && J == 8) return true; + if (I == 32 && J == 4) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 32) return true; + return false; + } + static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> return threadIdx.x % 16; @@ -89,7 +98,7 @@ namespace ggml_cuda_mma { } else if constexpr (I == 32 && J == 32) { return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -106,7 +115,7 @@ namespace ggml_cuda_mma { } else if constexpr (I == 32 && J == 32) { return threadIdx.x % 32; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -114,6 +123,11 @@ namespace ggml_cuda_mma { static constexpr int ne = I * J / 32; T x[ne] = {0}; + static constexpr __device__ bool supported() { + if (I == 32 && J == 8) return true; + return false; + } + static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 32 && J == 8) { #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM @@ -122,7 +136,7 @@ namespace ggml_cuda_mma { return (l & 2) | (threadIdx.x & ~2); #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -131,7 +145,7 @@ namespace 
ggml_cuda_mma { if constexpr (I == 32 && J == 8) { return (threadIdx.x & 2) | (l & (4 + 1)); } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -139,15 +153,28 @@ namespace ggml_cuda_mma { static constexpr int ne = I * J / 32; T x[ne] = {0}; + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + if (I == 8 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; + return false; + } + static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 8 && (J == 4 || J == 8)) { + if constexpr (I == 8 && J == 4) { return threadIdx.x / 4; - } else if constexpr ((I == 16 || I == 32) && J == 8) { + } else if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { return ((l / 2) * 8) | (threadIdx.x / 4); } else if constexpr (I == 16 && J == 16) { return (((l / 2) % 2) * 8) | (threadIdx.x / 4); + } else if constexpr (I == 32 && J == 8) { + return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -157,12 +184,14 @@ namespace ggml_cuda_mma { return threadIdx.x % 4; } else if constexpr (I == 8 && J == 8) { return (l * 4) | (threadIdx.x % 4); - } else if constexpr ((I == 16 || I == 32) && J == 8) { + } else if constexpr (I == 16 && J == 8) { return ((threadIdx.x % 4) * 2) | (l % 2); } else if constexpr (I == 16 && J == 16) { return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2); + } else if constexpr (I == 32 && J == 8) { + return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -178,6 +207,12 @@ namespace ggml_cuda_mma { static constexpr int ne = I == 8 && J == 8 ? 
I * J / (WARP_SIZE/4) : I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; + static constexpr __device__ bool supported() { + if (I == 8 && J == 8) return true; + if (I == 32 && J == 8) return true; + return false; + } + static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 8 && J == 8) { return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4); @@ -188,7 +223,7 @@ namespace ggml_cuda_mma { return threadIdx.x; #endif // GGML_CUDA_MMA_NO_VOLTA_PERM } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -197,7 +232,7 @@ namespace ggml_cuda_mma { if constexpr ((I == 8 || I == 32) && J == 8) { return l; } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -205,6 +240,15 @@ namespace ggml_cuda_mma { static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + if (I == 8 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; + return false; + } + static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; @@ -215,7 +259,7 @@ namespace ggml_cuda_mma { } else if constexpr (I == 32 && J == 8) { return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4); } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -230,7 +274,7 @@ namespace ggml_cuda_mma { } else if constexpr (I == 32 && J == 8) { return ((l & 2) * 2) | (threadIdx.x % 4); } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -244,6 +288,13 @@ namespace ggml_cuda_mma { static constexpr int ne = I * J / WARP_SIZE; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + static constexpr __device__ bool supported() { + if (I == 8 && J == 8) return true; + if (I == 16 && J == 4) return true; + if (I == 16 && J == 8) return true; + return false; + } + static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 8 && J == 8) { return threadIdx.x / 4; @@ -252,7 +303,7 @@ namespace ggml_cuda_mma { } else if constexpr (I == 16 && J == 8) { return ((l % 2) * 8) | (threadIdx.x / 4); } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } @@ -265,7 +316,7 @@ namespace ggml_cuda_mma { } else if constexpr (I == 16 && J == 8) { return ((l / 2) * 4) | (threadIdx.x % 4); } else { - static_assert(I == -1 && J == -1, "template specialization not implemented"); + NO_DEVICE_CODE; return -1; } } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index f4bad0c8f2c63..d6985a2a687b3 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -20,23 +20,27 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); template -static __device__ void mul_mat_f_impl( +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int 
stride_col_y, const int stride_col_dst, const int stride_col_id, const int stride_row_id, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - typedef tile<32, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<32, 8, float> tile_C; -#else - // In principle also possible to use tiles with I == 32, the performance difference is ~1%. - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); + constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + + if (!I_16_supported && !I_32_supported) { + NO_DEVICE_CODE; + return; + } + + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster. + + typedef tile tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile tile_C; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -238,43 +242,10 @@ static __device__ void mul_mat_f_impl( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int stride_col_id, const int stride_row_id, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { - if constexpr (std::is_same_v) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - mul_mat_f_impl( - x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, - stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); -#else - NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } else if constexpr (std::is_same_v || std::is_same_v) { -#ifdef AMPERE_MMA_AVAILABLE - mul_mat_f_impl( - x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, - stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); -#else - NO_DEVICE_CODE; -#endif // AMPERE_MMA_AVAILABLE - } else { - static_assert(std::is_same_v, "bad type"); - } - GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, - stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); -} - //This kernel is for larger batch sizes of mul_mat_id template -static __device__ void mul_mat_f_ids_impl( +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) 
+static __global__ void mul_mat_f_ids( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, @@ -283,16 +254,19 @@ static __device__ void mul_mat_f_ids_impl( const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA - typedef tile<32, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<32, 8, float> tile_C; -#else - // In principle also possible to use tiles with I == 32, the performance difference is ~1%. - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); + constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); + + if (!I_16_supported && !I_32_supported) { + NO_DEVICE_CODE; + return; + } + + constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster. + + typedef tile tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile tile_C; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int tile_k_padded = warp_size + 4; @@ -521,46 +495,6 @@ static __device__ void mul_mat_f_ids_impl( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f_ids( - const T * __restrict__ x, const float * __restrict__ y, - const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, - const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, - const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const uint3 sis1_fd, const uint3 nch_fd) { - if constexpr (std::is_same_v) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - mul_mat_f_ids_impl( - x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, - ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); -#else - NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) - } else if constexpr (std::is_same_v || std::is_same_v) { -#ifdef AMPERE_MMA_AVAILABLE - mul_mat_f_ids_impl( - x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, - ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); -#else - NO_DEVICE_CODE; -#endif // AMPERE_MMA_AVAILABLE - } else { - static_assert(std::is_same_v, "bad type"); - } - GGML_UNUSED_VARS( - x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, - ncols, ncols_dst_total, nchannels_dst, 
stride_row, stride_col_y, stride_col_dst,
-        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-}
-
 template
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
@@ -618,7 +552,7 @@ void mul_mat_f_cuda(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream, const mmf_ids_data * ids_data) {
-    typedef tile<32, 8, T> tile_A_16;
+    typedef tile<16, 8, T> tile_A_16;
     typedef tile<32, 8, T> tile_A_32;
     typedef tile< 8, 8, T> tile_B;
 
@@ -630,7 +564,8 @@ void mul_mat_f_cuda(
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
     const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
 
-    const int device = ggml_cuda_get_device();
+    const int device    = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[device].cc;
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
 
     int64_t nwarps_best = 1;
@@ -645,7 +580,7 @@ void mul_mat_f_cuda(
     }
 
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * (volta_mma_available ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
+    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
     const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;

From 53badac6c343bf9f6811c84337dacce9934b32ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 31 Oct 2025 07:27:05 +0100
Subject: [PATCH 3/3] Update ggml/src/ggml-cuda/mmf.cuh

Co-authored-by: Aman Gupta
---
 ggml/src/ggml-cuda/mmf.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index d6985a2a687b3..f7e46e2f63b2f 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -36,7 +36,7 @@ static __global__ void mul_mat_f(
         return;
     }
 
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
 
     typedef tile<I_preferred, 8, T>     tile_A;
     typedef tile<8, 8, T>               tile_B;
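For reference, the compile-time shape selection that patches 2 and 3 converge on can be sketched in isolation as follows. The tile_sketch type and pick_preferred_I() below are stand-ins invented for this note, not part of ggml, and the assumed shape set is simplified; the real kernels bail out with NO_DEVICE_CODE at run time instead of the static_assert used here.

// Stand-in for ggml_cuda_mma::tile: only the supported() query matters for this sketch.
template <int I, int J, typename T>
struct tile_sketch {
    static constexpr bool supported() {
        return (I == 16 || I == 32) && J == 8; // assumed shape set; the real code lists shapes per architecture
    }
};

template <typename T>
constexpr int pick_preferred_I() {
    constexpr bool I_16_supported = tile_sketch<16, 8, T>::supported() && tile_sketch<16, 8, float>::supported();
    constexpr bool I_32_supported = tile_sketch<32, 8, T>::supported() && tile_sketch<32, 8, float>::supported();
    static_assert(I_16_supported || I_32_supported, "no usable tile shape");
    return I_16_supported ? 16 : 32; // prefer 16 where both work (Turing+); ~1% faster per the patch
}

static_assert(pick_preferred_I<float>() == 16, "with both shapes available the 16-row tile wins");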