
Commit eed07b8

Revert commit "HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA"
1 parent 21b7d0a commit eed07b8

5 files changed: +700, -1295 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 3 additions & 13 deletions
@@ -56,7 +56,7 @@
 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
 #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

@@ -72,9 +72,8 @@
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
 #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
-#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
-#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
-#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
+#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)

 // Moore Threads
 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000

@@ -231,10 +230,6 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
-#define AMD_MFMA_AVAILABLE
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
-
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

@@ -297,11 +292,6 @@ static bool fp32_mma_hardware_available(const int cc) {
     return GGML_CUDA_CC_IS_CDNA(cc);
 }

-// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
-static bool amd_mfma_available(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
-}
-
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
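For orientation (not part of the commit): a minimal host-side sketch of how the range-based GGML_CUDA_CC_IS_GCN / GGML_CUDA_CC_IS_CDNA checks classify an AMD compute capability once the CDNA1 name is reverted to CDNA. The OFFSET_AMD and RDNA1 values below are placeholder assumptions; only the 0x908 and 0x942 offsets come from the diff above.

#include <cstdio>

// Placeholder constants mirroring the macros; the real values live elsewhere in common.cuh.
constexpr int CC_OFFSET_AMD = 0x1000000;              // assumed placeholder
constexpr int CC_CDNA       = CC_OFFSET_AMD + 0x908;  // MI100 (offset from the diff)
constexpr int CC_RDNA1      = CC_OFFSET_AMD + 0x1010; // assumed placeholder

constexpr bool is_gcn (int cc) { return cc >  CC_OFFSET_AMD && cc < CC_CDNA;  } // GGML_CUDA_CC_IS_GCN
constexpr bool is_cdna(int cc) { return cc >= CC_CDNA       && cc < CC_RDNA1; } // GGML_CUDA_CC_IS_CDNA

int main() {
    const int mi100 = CC_OFFSET_AMD + 0x908; // gfx908
    const int mi300 = CC_OFFSET_AMD + 0x942; // gfx942, still matched by the generic CDNA range
    std::printf("MI100: gcn=%d cdna=%d\n", is_gcn(mi100), is_cdna(mi100)); // gcn=0 cdna=1
    std::printf("MI300: gcn=%d cdna=%d\n", is_gcn(mi300), is_cdna(mi300)); // gcn=0 cdna=1
}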

ggml/src/ggml-cuda/mma.cuh

Lines changed: 3 additions & 111 deletions
@@ -12,8 +12,7 @@
 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
-// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.
+// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.

 #include "common.cuh"

@@ -67,44 +66,7 @@ namespace ggml_cuda_mma {
     struct tile {
         static constexpr int I = I_;
         static constexpr int J = J_;
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-        static constexpr int ne = I * J / 64;
-        T x[ne] = {0};
-
-        static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
-                return threadIdx.x % 16;
-            } else if constexpr (I == 16 && J == 8) {
-                return threadIdx.x % 16;
-            } else if constexpr (I == 32 && J == 4) {
-                return threadIdx.x % 32;
-            } else if constexpr (I == 16 && J == 16) {
-                return 4 * (threadIdx.x / 16) + l;
-            } else if constexpr (I == 32 && J == 32) {
-                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
-            } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
-            }
-        }
-
-        static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
-                return (2 * ((threadIdx.x / 16) % 2) + l);
-            } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x / 16) + l;
-            } else if constexpr (I == 32 && J == 4) {
-                return 2 * (threadIdx.x / 32) + l;
-            } else if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
-            } else if constexpr (I == 32 && J == 32) {
-                return threadIdx.x % 32;
-            } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
-            }
-        }
-#else
-        static constexpr int ne = I * J / 32;
+        static constexpr int ne = I * J / WARP_SIZE;
         T x[ne] = {0};

         static __device__ __forceinline__ int get_i(const int l) {

@@ -132,7 +94,6 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     };

     template <int I_, int J_>

@@ -187,23 +148,10 @@ namespace ggml_cuda_mma {

     template <int I, int J, typename T>
     static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(AMD_MFMA_AVAILABLE)
-        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
-#pragma unroll
-            for (int l = 0; l < t.ne; ++l) {
-                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
-            }
-        } else {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
-        }
-#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
-#endif // defined(AMD_MFMA_AVAILABLE)
     }

     template <typename T>

@@ -238,7 +186,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(NEW_MMA_AVAILABLE)
+#ifdef NEW_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"

@@ -445,60 +393,4 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
-
-    static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
-#if defined(AMD_MFMA_AVAILABLE)
-        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
-        int32x4_t * acc = (int32x4_t *) D.x;
-#if defined(CDNA3)
-        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
-                                                       ((int64_t *) B.x)[0],
-                                                       acc[0],
-                                                       0, 0, 0);
-#elif defined(CDNA2) || defined(CDNA)
-        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
-                                                      B.x[0],
-                                                      acc[0],
-                                                      0, 0, 0);
-        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
-                                                      B.x[1],
-                                                      acc[0],
-                                                      0, 0, 0);
-#endif // defined(CDNA3)
-#else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
-        NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
-    }
-
-    static __device__ __forceinline__ void mma(
-            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
-#if defined(AMD_MFMA_AVAILABLE)
-        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
-        int32x16_t * acc = (int32x16_t *) D.x;
-#if defined(CDNA3)
-        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
-                                                       ((int64_t *) B.x)[0],
-                                                       acc[0],
-                                                       0, 0, 0);
-#elif defined(CDNA2) || defined(CDNA)
-        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
-                                                     B.x[0],
-                                                     acc[0],
-                                                     0, 0, 0);
-        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
-                                                     B.x[1],
-                                                     acc[0],
-                                                     0, 0, 0);
-#endif // defined(CDNA3)
-#else
-        GGML_UNUSED(D);
-        GGML_UNUSED(A);
-        GGML_UNUSED(B);
-        NO_DEVICE_CODE;
-#endif // AMD_MFMA_AVAILABLE
-    }
 }
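For orientation (not part of the commit): a host-side sketch of what the unified tile keeps after the revert, namely ne = I * J / WARP_SIZE 32-bit elements per thread and the generic gather t.x[l] = xs0[get_i(l)*stride + get_j(l)] used by load_generic. The get_i/get_j mapping below is a simplified row-major stand-in, not the real ldmatrix-compatible layout from mma.cuh.

#include <cstdio>
#include <vector>

constexpr int WARP_SIZE = 32;

// Stand-in for ggml_cuda_mma::tile: same ne arithmetic, simplified index mapping.
template <int I, int J>
struct tile_sketch {
    static constexpr int ne = I * J / WARP_SIZE; // 32-bit elements owned by each thread
    int x[ne] = {0};

    // Simplified row-major ownership: lane `lane` takes ne consecutive elements.
    static int get_i(int lane, int l) { return (lane * ne + l) / J; }
    static int get_j(int lane, int l) { return (lane * ne + l) % J; }
};

int main() {
    constexpr int I = 16, J = 8, stride = 8;
    std::vector<int> xs0(I * stride);
    for (int i = 0; i < (int) xs0.size(); ++i) xs0[i] = i;

    tile_sketch<I, J> t;                 // ne == 16*8/32 == 4
    const int lane = 5;                  // one warp lane
    for (int l = 0; l < t.ne; ++l) {     // same loop shape as the kept load_generic path
        t.x[l] = xs0[t.get_i(lane, l) * stride + t.get_j(lane, l)];
    }
    std::printf("ne=%d x={%d,%d,%d,%d}\n", t.ne, t.x[0], t.x[1], t.x[2], t.x[3]);
}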

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 6 deletions
@@ -109,8 +109,7 @@ void ggml_cuda_mul_mat_q(
     const int64_t s03 = src0->nb[3] / ts_src0;
     const int64_t s3 = dst->nb[3] / ts_dst;

-    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-        || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;

     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +

@@ -251,9 +250,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-        || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
-        && src1_ncols == ne11;
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;

     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,

@@ -308,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }

-    if (new_mma_available(cc) || amd_mfma_available(cc)) {
+    if (new_mma_available(cc)) {
         return true;
     }
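For orientation (not part of the commit): a host-side sketch of the shape of the reverted use_stream_k gate, NVIDIA Volta or newer plus src1_ncols == ne11 so the stream-k fixup buffer is not shared across streams. The helpers is_nvidia and highest_compiled_arch are hypothetical stand-ins for GGML_CUDA_CC_IS_NVIDIA and ggml_cuda_highest_compiled_arch, and the Volta threshold 700 is an assumption.

#include <cstdint>
#include <cstdio>

constexpr int CC_VOLTA = 700; // assumed encoding for compute capability 7.0

// Hypothetical stand-ins for GGML_CUDA_CC_IS_NVIDIA / ggml_cuda_highest_compiled_arch.
static bool is_nvidia(int cc)             { return cc < 0x1000000; }
static int  highest_compiled_arch(int cc) { return cc; } // assume the arch itself was compiled in

static bool use_stream_k(int cc, int64_t src1_ncols, int64_t ne11) {
    // Reverted gate: recent NVIDIA only, and a single stream (src1_ncols == ne11)
    // so the stream-k fixup buffer cannot be raced.
    return is_nvidia(cc) && highest_compiled_arch(cc) >= CC_VOLTA && src1_ncols == ne11;
}

int main() {
    std::printf("Turing, one stream : %d\n", use_stream_k(750, 64, 64)); // 1
    std::printf("Turing, split cols : %d\n", use_stream_k(750, 32, 64)); // 0
    std::printf("Pascal, one stream : %d\n", use_stream_k(610, 64, 64)); // 0
}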