
Commit 95e571c

Author: hzhang13 (committed)
mmf 4 rdna4

1 parent 338074c, commit 95e571c

4 files changed: +206 -6 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 0 deletions

@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
+#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)
+#define AMD_WMMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

@@ -278,6 +282,14 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }
 
+static bool amd_wmma_available(const int cc) {
+#if !defined(GGML_HIP_NO_WMMA)
+    return GGML_CUDA_CC_IS_RDNA4(cc);
+#else
+    return false;
+#endif // !defined(GGML_HIP_NO_WMMA)
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
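
Note: this file adds a two-level gate. `AMD_WMMA_AVAILABLE` selects the WMMA device-code path at compile time (HIP build, RDNA4 target, WMMA not disabled), while `amd_wmma_available(cc)` is the host-side runtime check on the device's compute capability. Below is a minimal standalone sketch of the runtime half; the capability constants and the `GGML_CUDA_CC_IS_RDNA4` macro are stubbed with placeholder values for illustration and are not the real definitions from common.cuh.

```cpp
// Standalone sketch of the runtime gate, not part of the commit.
// The constants below are placeholders (assumptions), NOT the real values in common.cuh.
#include <cstdio>

#define FAKE_CC_RDNA4 0x1200   // stand-in for a gfx12 compute capability
#define FAKE_CC_RDNA3 0x1100   // stand-in for a gfx11 compute capability
#define GGML_CUDA_CC_IS_RDNA4(cc) ((cc) >= FAKE_CC_RDNA4)  // simplified stand-in for the real macro

static bool amd_wmma_available(const int cc) {
#if !defined(GGML_HIP_NO_WMMA)
    return GGML_CUDA_CC_IS_RDNA4(cc);   // RDNA4 (gfx12) is the only generation accepted
#else
    return false;                       // WMMA explicitly disabled at build time
#endif
}

int main() {
    printf("RDNA4 device: %d\n", amd_wmma_available(FAKE_CC_RDNA4));  // prints 1
    printf("RDNA3 device: %d\n", amd_wmma_available(FAKE_CC_RDNA3));  // prints 0
    return 0;
}
```

In mmf.cu this runtime check is what lets `ggml_cuda_should_use_mmf()` enable the F16/BF16 MMF path on RDNA4 GPUs.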

ggml/src/ggml-cuda/mma.cuh

Lines changed: 141 additions & 0 deletions

@@ -70,6 +70,7 @@ namespace ggml_cuda_mma {
         static constexpr int J = J_;
 
 #if defined(GGML_USE_HIP)
+#if defined(CDNA)
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 

@@ -104,6 +105,30 @@
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#elif defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32; only keeps the compiler happy, do not use it in real kernels
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32; only keeps the compiler happy, do not use it in real kernels
+                return threadIdx.x % 16;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#endif // defined(CDNA)
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
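
Note: on the RDNA4 path each lane of the 32-wide wave owns `ne = I*J/32` elements; for the 16x16 accumulator tile that is 8 elements per lane, with lanes 0-15 holding rows 0-7 and lanes 16-31 holding rows 8-15 of the column `threadIdx.x % 16`. The host-side checker below (not part of the commit) replays the `get_i`/`get_j` formulas from this hunk and confirms that the 32 lanes cover all 256 tile entries exactly once.

```cpp
// Standalone host-side check of the RDNA4 16x16 tile mapping; not part of the commit.
#include <cassert>
#include <cstdio>

int main() {
    int owner[16][16] = {};
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 8; ++l) {               // ne = 16*16/32 = 8 elements per lane
            const int i = 8 * (lane / 16) + l;      // get_i for I == 16, J == 16
            const int j = lane % 16;                // get_j for I == 16, J == 16
            owner[i][j] += 1;
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            assert(owner[i][j] == 1);               // full, non-overlapping coverage
        }
    }
    printf("16x16 tile: 32 lanes x 8 elements cover all 256 entries exactly once\n");
    return 0;
}
```
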

@@ -140,6 +165,29 @@
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#endif // defined(RDNA4)
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};


@@ -166,12 +214,36 @@
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(AMD_WMMA_AVAILABLE)
     };
 
     template <int I_, int J_>
     struct tile<I_, J_, nv_bfloat162> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#endif // defined(RDNA4)
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
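
Note: the <16, 8> half2 / nv_bfloat162 operand tiles use the transposed mapping: the row is `threadIdx.x % 16` and each lane owns 4 consecutive packed pairs (8 scalars) of that row. That contiguity is what lets `load_generic` below fill a lane's fragment with a single vector load. A host-side sketch of the mapping, not part of the commit:

```cpp
// Standalone host-side check of the RDNA4 <16, 8> packed operand tile mapping; not part of the commit.
#include <cassert>
#include <cstdio>

int main() {
    int owner[16][8] = {};                       // 16 rows x 8 half2 columns (16 scalar columns)
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 4; ++l) {            // ne = 16*8/32 = 4 packed pairs per lane
            const int i = lane % 16;             // get_i for I == 16, J == 8
            const int j = 4 * (lane / 16) + l;   // get_j for I == 16, J == 8
            owner[i][j] += 1;
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 8; ++j) {
            assert(owner[i][j] == 1);            // each lane owns one contiguous 4-pair run of a row
        }
    }
    printf("<16,8> packed tile: every entry is owned by exactly one lane\n");
    return 0;
}
```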


@@ -198,6 +270,7 @@
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(AMD_WMMA_AVAILABLE)
     };
 
     template <int I, int J>

@@ -231,6 +304,19 @@
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
        }
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        // Special tile size to load <16, 8> as <16, 16> for half2 and __hip_bfloat162
+        if constexpr (I == 16 && J == 8 && (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value)) {
+            constexpr int RDNA4_WMMA_MEM_N = 4;
+            using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int;
+            reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
+        } else {
+            constexpr int RDNA4_WMMA_MEM_N = 8;
+            using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T;
+            reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
+        }
+#endif // defined(RDNA4)
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
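
Note: `ext_vector_type` is a Clang vector extension (hipcc is Clang-based), so the `reinterpret_cast` in the hunk above turns the per-lane copy into one wide load instead of `ne` scalar loads. Below is a host-compilable sketch of the same idiom; it is not part of the commit and requires Clang.

```cpp
// Clang-only host sketch of the ext_vector_type load idiom used in load_generic; not part of the commit.
#include <cstdio>

using intx4_t = __attribute__((ext_vector_type(4))) int;    // 4 x 32-bit = 16 bytes, as in the <16,8> case

int main() {
    alignas(16) int src[8] = {1, 2, 3, 4, 5, 6, 7, 8};       // stand-in for one row of the source tile
    alignas(16) int dst[4] = {};                             // stand-in for a lane's t.x[] storage

    // One 16-byte copy, the moral equivalent of the reinterpret_cast load in load_generic.
    reinterpret_cast<intx4_t &>(dst[0]) = reinterpret_cast<const intx4_t &>(src[0]);

    printf("%d %d %d %d\n", dst[0], dst[1], dst[2], dst[3]); // prints 1 2 3 4
    return 0;
}
```
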

@@ -461,6 +547,25 @@
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, float> & A, const tile<16, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, "
+            "%2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
 #ifdef TURING_MMA_AVAILABLE

@@ -489,12 +594,48 @@
             : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using halfx8_t  = __attribute__((ext_vector_type(8))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t & acc_frag    = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx8_t & a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t & b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#endif // defined(RDNA4)
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using bf16x8_t  = __attribute__((ext_vector_type(8))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t & acc_frag    = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x8_t & a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
+        const bf16x8_t & b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#endif // defined(RDNA4)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
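
Note: per 32-lane wave, `__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12` (and its bf16 sibling) performs a 16x16x16 matrix multiply-accumulate with f32 accumulation, D = A * B + C. The reference below, not part of the commit, spells out that arithmetic with plain floats standing in for f16; how the 256 values are distributed over lanes is exactly what the tile `get_i`/`get_j` mappings and the 8-wide fragments above encode, and operand orientation details are abstracted away here.

```cpp
// Host-side reference for the per-wave WMMA math (f32 accumulate); not part of the commit.
#include <cstdio>

int main() {
    float A[16][16], B[16][16], C[16][16], D[16][16];
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            A[i][j] = (i == j) ? 1.0f : 0.0f;   // A = identity for an easy check
            B[i][j] = float(i * 16 + j);
            C[i][j] = 1.0f;                     // existing accumulator contents
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int n = 0; n < 16; ++n) {
            float acc = C[i][n];                // the builtin accumulates into the C fragment
            for (int k = 0; k < 16; ++k) {
                acc += A[i][k] * B[k][n];       // K = 16 per instruction
            }
            D[i][n] = acc;
        }
    }
    printf("D[3][5] = %.1f (expected %.1f)\n", D[3][5], float(3 * 16 + 5) + 1.0f);
    return 0;
}
```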

ggml/src/ggml-cuda/mmf.cu

Lines changed: 2 additions & 2 deletions

@@ -148,9 +148,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return turing_mma_available(cc);
+            return turing_mma_available(cc) || amd_wmma_available(cc);
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc);
         default:
             return false;
     }

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 51 additions & 4 deletions

@@ -27,10 +27,16 @@ static __global__ void mul_mat_f(
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    typedef tile<16, 8, T>      tile_A;
+    typedef tile<16, 8, T>      tile_B;
+    typedef tile<16, 16, float> tile_C;
+#else
     typedef tile<16, 8, T> tile_A;
     typedef tile< 8, 8, T> tile_B;
     typedef tile<16, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
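
Note: the WMMA path changes the fragment geometry: tile_B grows from 8x8 to 16x8 and the accumulator tile_C from 16x8 to 16x16, matching the 16x16x16 shape of the RDNA4 WMMA instruction. A quick per-lane element count (not part of the commit), using `ne = I*J/32` from mma.cuh, with A/B counted in packed half2/nv_bfloat162 pairs and C in floats:

```cpp
// Per-lane fragment sizes implied by the two tile configurations; not part of the commit.
#include <cstdio>

int main() {
    printf("NVIDIA path: A %d, B %d, C %d entries per lane\n", 16 * 8 / 32,  8 * 8 / 32, 16 *  8 / 32);
    printf("RDNA4 path : A %d, B %d, C %d entries per lane\n", 16 * 8 / 32, 16 * 8 / 32, 16 * 16 / 32);
    return 0;
}
```

The 8 floats per lane in the RDNA4 tile_C are exactly the `floatx8_t` accumulator fragment passed to the WMMA builtin in mma.cuh.
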

@@ -151,11 +157,31 @@ static __global__ void mul_mat_f(
 
             if constexpr (!has_ids) {
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+#if !defined(GGML_USE_HIP)
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+#else
+                if constexpr (std::is_same<T, half2>::value) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
+                } else if constexpr (std::is_same<T, nv_bfloat162>::value) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
+                } else {
+                    static_assert(std::is_same_v<T, void>, "unsupported type");
+                }
+#endif // !defined(GGML_USE_HIP)
             } else {
                 const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                 float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
+#if !defined(GGML_USE_HIP)
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+#else
+                if constexpr (std::is_same<T, half2>::value) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
+                } else if constexpr (std::is_same<T, nv_bfloat162>::value) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
+                } else {
+                    static_assert(std::is_same_v<T, void>, "unsupported type");
+                }
+#endif // !defined(GGML_USE_HIP)
             }
         }
     } else {
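
Note: on the CUDA path the brace initializer `{tmp.x, tmp.y}` converts the two floats into a half2 or nv_bfloat162 implicitly, while the HIP path spells the conversion out with `__float22half2_rn` / `__float22bfloat162_rn`. A minimal CUDA sketch of the two spellings for the half2 case, not part of the commit; the only assumption is that both forms apply the same round-to-nearest conversion.

```cpp
// Minimal CUDA demo of the two float2 -> half2 conversion spellings; not part of the commit.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void convert(const float2 in, half2 * out_brace, half2 * out_intrinsic) {
    *out_brace     = {in.x, in.y};            // spelling used when GGML_USE_HIP is not defined
    *out_intrinsic = __float22half2_rn(in);   // explicit-rounding spelling used on the HIP path
}

int main() {
    half2 * d;
    cudaMalloc(&d, 2 * sizeof(half2));
    convert<<<1, 1>>>(make_float2(1.5f, -2.25f), d, d + 1);
    half2 h[2];
    cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
    printf("brace: %f %f   intrinsic: %f %f\n",
           __half2float(h[0].x), __half2float(h[0].y),
           __half2float(h[1].x), __half2float(h[1].y));
    cudaFree(d);
    return 0;
}
```
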

@@ -229,7 +255,7 @@ static __global__ void mul_mat_f(
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }

@@ -244,10 +270,16 @@ static __global__ void mul_mat_f_ids(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    typedef tile<16, 8, T>      tile_A;
+    typedef tile<16, 8, T>      tile_B;
+    typedef tile<16, 16, float> tile_C;
+#else
     typedef tile<16, 8, T> tile_A;
     typedef tile< 8, 8, T> tile_B;
     typedef tile<16, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;

@@ -389,7 +421,17 @@ static __global__ void mul_mat_f_ids(
 #pragma unroll
         for (int j0 = 0; j0 < tile_B::I; ++j0) {
             const float2 tmp = vals_buf[curr_buf][j0];
+#if !defined(GGML_USE_HIP)
             tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+#else
+            if constexpr (std::is_same<T, half2>::value) {
+                tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
+            } else if constexpr (std::is_same<T, nv_bfloat162>::value) {
+                tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#endif // !defined(GGML_USE_HIP)
         }
 
         if (itB + 1 < ntB) {

@@ -473,7 +515,7 @@ static __global__ void mul_mat_f_ids(
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
 template<typename T, int cols_per_block, int nwarps>

@@ -533,8 +575,13 @@
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream, const mmf_ids_data * ids_data) {
+#if defined(GGML_USE_HIP)
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<16, 8, T> tile_B;
+#else
     typedef tile<16, 8, T> tile_A;
     typedef tile< 8, 8, T> tile_B;
+#endif // defined(GGML_USE_HIP)
 
     GGML_ASSERT(ncols_x % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
