18 changes: 14 additions & 4 deletions ggml/src/ggml-cuda/common.cuh
@@ -56,7 +56,7 @@
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum for acc register renaming
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

@@ -72,8 +72,9 @@
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
@@ -211,6 +212,10 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
#define AMD_MMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -252,6 +257,11 @@ static bool fp16_mma_hardware_available(const int cc) {
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}

// AMD CDNA3 matrix cores. Support for other CDNA generations will be added later.
static bool amd_mma_available(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
@@ -261,7 +271,7 @@ static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

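// Note (assumption, inferred from the diff): __host__ is added so that tile<> in
// mma.cuh can call this in a constexpr member initializer, which must also be
// evaluable during the host compilation pass.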
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
static constexpr __host__ __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
#else
136 changes: 116 additions & 20 deletions ggml/src/ggml-cuda/mma.cuh
@@ -64,34 +64,69 @@ namespace ggml_cuda_mma {

template <int I_, int J_, typename T>
struct tile {
static constexpr int warp_size = ggml_cuda_get_physical_warp_size();
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
static constexpr int ne = I * J / warp_size;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 16) {
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
if constexpr (warp_size == 32) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 16) {
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
} else if constexpr (warp_size == 64) {
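// Wave64 (CDNA) mapping: each lane owns ne = I*J/64 tile elements, laid out
// to match the MFMA register layout; l selects among the lane's elements.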
if constexpr (I == 8 && (J == 4 || J == 8)) { // TODO: remove this case
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
} else if constexpr (I == 16 && J == 16) {
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
if constexpr (warp_size == 32) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
} else if constexpr (I == 16 && J == 16) {
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
} else if constexpr (warp_size == 64) {
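// Wave64 column mapping, complementing get_i above.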
if constexpr (I == 8 && J == 4) { // TODO: remove this case
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) { // TODO: remove this case
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
}
};
@@ -186,7 +221,11 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#if defined(AMD_MMA_AVAILABLE)
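// One 8-byte load per wave64 lane: two consecutive 4-byte elements from
// row threadIdx.x % t.I, starting at int column 2 * (threadIdx.x / t.I).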
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#elif defined(NEW_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -197,6 +236,23 @@ namespace ggml_cuda_mma {
#endif // AMD_MMA_AVAILABLE
}

template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<32, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MMA_AVAILABLE)
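// Same scheme as the 16x8 tile: one 8-byte load per lane covers both of the
// lane's 4-byte elements.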
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#elif defined(NEW_MMA_AVAILABLE)
GGML_UNUSED(t);
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#else
load_generic(t, xs0, stride);
#endif // AMD_MMA_AVAILABLE
}

template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -386,6 +442,46 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MMA_AVAILABLE)
#if defined(CDNA3)
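// CDNA3 16x16x32 int8 MFMA: each lane passes 8 packed int8 values of A and
// of B as one int64 and accumulates into 4 int32 values (int32x4_t).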
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t* acc = (int32x4_t*) D.x;
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((const int64_t *) A.x)[0],
((const int64_t *) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
// TODO: placeholder for CDNA 1/2 support (AMD_MMA_AVAILABLE is currently CDNA3-only).
#endif
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMD_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MMA_AVAILABLE)
#if defined(CDNA3)
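// CDNA3 32x32x16 int8 MFMA: same operand packing, with 16 int32 accumulator
// values per lane (int32x16_t).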
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
int32x16_t* acc = (int32x16_t*) D.x;
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((const int64_t *) A.x)[0],
((const int64_t *) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
// TODO: placeholder for CDNA 1/2 support (AMD_MMA_AVAILABLE is currently CDNA3-only).
#endif
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmq.cu
@@ -304,7 +304,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

if (new_mma_available(cc)) {
if (new_mma_available(cc) || amd_mma_available(cc)) {
return true;
}
