Skip to content

Commit 0543f92

Browse files
HIP: WMMA-MMQ kernels for RDNA 4 (ggml-org#17156)
* first commit naive test to enable mmq for RDNA4 * adding appropriate WMMA instructions * git rebase on top of master: fixing the correctness of the mat mul operations, updating layout mappings for RDNA4 * clean up merge conflicts * add comments and code clean up * PR clean up, addressed comments * enable MMQ fallback on RDNA4 * addressed comments: add guards in load generic, separate wmma branch for use_mmq function * Revert build-xcframework.sh * Formatting: remove trailing whitespace * revert CMake files * clean up after rebase: remove duplicated change, revert cmake files * clean up after rebase: revert changes from build-xcframework.sh * clean up: remove extra space line in mma.cuh * Revert "clean up: remove extra space line in mma.cuh" This reverts commit b39ed57.
1 parent b61de2b commit 0543f92

File tree

3 files changed

+408
-168
lines changed

3 files changed

+408
-168
lines changed

ggml/src/ggml-cuda/mma.cuh

Lines changed: 101 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
7373
static constexpr int I = I_;
7474
static constexpr int J = J_;
7575

76-
#if defined(GGML_USE_HIP)
77-
#if defined(RDNA4)
78-
static constexpr int ne = I * J / 32;
79-
T x[ne] = {0};
80-
81-
static constexpr __device__ bool supported() {
82-
if (I == 16 && J == 16) return true;
83-
return false;
84-
}
85-
86-
static __device__ __forceinline__ int get_i(const int l) {
87-
if constexpr (I == 16 && J == 16) {
88-
return 8 * (threadIdx.x / 16) + l;
89-
} else {
90-
NO_DEVICE_CODE;
91-
return -1;
92-
}
93-
}
94-
95-
static __device__ __forceinline__ int get_j(const int l) {
96-
if constexpr (I == 16 && J == 16) {
97-
return threadIdx.x % 16;
98-
} else {
99-
NO_DEVICE_CODE;
100-
return -1;
101-
}
102-
}
103-
#else
76+
#if defined(AMD_MFMA_AVAILABLE)
10477
static constexpr int ne = I * J / 64;
10578
T x[ne] = {0};
10679

@@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
146119
return -1;
147120
}
148121
}
149-
#endif // defined(RDNA4)
150122
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
151123
static constexpr int ne = I * J / 32;
152124
T x[ne] = {0};
@@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
177149
return -1;
178150
}
179151
}
152+
#elif defined(AMD_WMMA_AVAILABLE)
153+
#if defined(RDNA4)
154+
static constexpr int ne = I * J / 32;
155+
T x[ne] = {0};
156+
157+
static constexpr __device__ bool supported() {
158+
if (I == 16 && J == 16) return true;
159+
return false;
160+
}
161+
162+
static __device__ __forceinline__ int get_i(const int l) {
163+
if constexpr (I == 16 && J == 16) {
164+
return 8 * (threadIdx.x / 16) + l;
165+
} else {
166+
NO_DEVICE_CODE;
167+
return -1;
168+
}
169+
}
170+
171+
static __device__ __forceinline__ int get_j(const int l) {
172+
if constexpr (I == 16 && J == 16) {
173+
return threadIdx.x % 16;
174+
} else {
175+
NO_DEVICE_CODE;
176+
return -1;
177+
}
178+
}
179+
#endif
180180
#else
181181
static constexpr int ne = I * J / 32;
182182
T x[ne] = {0};
@@ -437,7 +437,20 @@ namespace ggml_cuda_mma {
437437
xi[0] = xs[0];
438438
}
439439
#elif defined(AMD_WMMA_AVAILABLE)
440-
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
440+
if constexpr (I == 16 && J == 4) {
441+
int64_t * xi = (int64_t *) t.x;
442+
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
443+
xi[0] = xs[0];
444+
}else if constexpr (I == 16 && J == 8) {
445+
int64_t * xi = (int64_t *) t.x;
446+
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
447+
xi[0] = xs[0];
448+
449+
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
450+
xi[1] = xs1[0];
451+
}else{
452+
NO_DEVICE_CODE;
453+
}
441454
#else
442455
#pragma unroll
443456
for (int l = 0; l < t.ne; ++l) {
@@ -772,6 +785,36 @@ namespace ggml_cuda_mma {
772785
acc[0],
773786
0, 0, 0);
774787
#endif // defined(CDNA3)
788+
789+
#elif defined(AMD_WMMA_AVAILABLE)
790+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
791+
int32x2_t * a_vec = (int32x2_t *) A.x;
792+
int32x2_t * b_vec = (int32x2_t *) B.x;
793+
794+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
795+
int32x8_t * acc = (int32x8_t *) D.x;
796+
797+
#if defined(RDNA4)
798+
799+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
800+
true,
801+
a_vec[0],
802+
true,
803+
b_vec[0],
804+
acc[0],
805+
true
806+
);
807+
808+
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
809+
true,
810+
a_vec[1],
811+
true,
812+
b_vec[1],
813+
acc[0],
814+
true
815+
);
816+
#endif // defined(RDNA4)
817+
775818
#else
776819
GGML_UNUSED_VARS(D, A, B);
777820
NO_DEVICE_CODE;
@@ -798,6 +841,7 @@ namespace ggml_cuda_mma {
798841
acc[0],
799842
0, 0, 0);
800843
#endif // defined(CDNA3)
844+
801845
#else
802846
GGML_UNUSED_VARS(D, A, B);
803847
NO_DEVICE_CODE;
@@ -842,4 +886,31 @@ namespace ggml_cuda_mma {
842886
mma(D16[1], A16[1], B);
843887
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
844888
}
889+
890+
// mma overload taking a 16x16 int accumulator tile D and 16x4 int tiles A and B.
// With AMD_WMMA_AVAILABLE defined, the tile storage is reinterpreted as vectors
// (A.x / B.x as 2-int vectors, D.x as an 8-int vector) and one WMMA builtin is issued
// with D's vector as both the accumulator input and the destination of the result.
// Otherwise the arguments are marked unused and NO_DEVICE_CODE is emitted.
// NOTE(review): the builtin carries a _gfx12 suffix but is guarded here only by
// AMD_WMMA_AVAILABLE, whereas the other AMD_WMMA_AVAILABLE branch in this file that calls
// the same builtin wraps it in #if defined(RDNA4) - confirm AMD_WMMA_AVAILABLE implies RDNA4.
static __device__ __forceinline__ void mma(
891+
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
892+
#if defined(AMD_WMMA_AVAILABLE)
893+
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
894+
// View the per-thread A and B registers as 2-int vectors; only element [0] is used below.
int32x2_t * a_vec = (int32x2_t *) A.x;
895+
int32x2_t * b_vec = (int32x2_t *) B.x;
896+
897+
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
898+
// View the per-thread accumulator registers of D as one 8-int vector.
int32x8_t * acc = (int32x8_t *) D.x;
899+
900+
// Single builtin call: acc[0] is passed in as the accumulator and overwritten with the result.
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
901+
// presumably "A is signed" - TODO confirm against the LLVM AMDGPU builtin documentation
true,
902+
a_vec[0],
903+
// presumably "B is signed" - TODO confirm against the LLVM AMDGPU builtin documentation
true,
904+
b_vec[0],
905+
acc[0],
906+
// NOTE(review): this last flag is false here but true at the other call sites of this
// builtin in this file; presumably it is the clamp/saturate flag - confirm which is intended.
false
907+
);
908+
#else
909+
// No WMMA support compiled in: silence unused-parameter warnings and mark the path as invalid device code.
GGML_UNUSED(D);
910+
GGML_UNUSED(A);
911+
GGML_UNUSED(B);
912+
NO_DEVICE_CODE;
913+
#endif
914+
}
845915
}
916+

ggml/src/ggml-cuda/mmq.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
306306
return false;
307307
}
308308

309-
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
309+
if (amd_wmma_available(cc)) {
310+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
311+
return true;
312+
}
313+
}
314+
315+
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
310316
}

0 commit comments

Comments
 (0)