@@ -70,6 +70,7 @@ namespace ggml_cuda_mma {
         static constexpr int J = J_;

 #if defined(GGML_USE_HIP)
+#if defined(CDNA)
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};

@@ -104,6 +105,30 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#elif defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32, just to keep the compiler happy, do not use it in the real case
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) { // dummy shape to handle TF32, just to keep the compiler happy, do not use it in the real case
+                return threadIdx.x % 16;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#endif // defined(CDNA)
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
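
In the RDNA4 branch above, a 16x16 tile is spread across a 32-lane wave: each lane owns `ne = 8` elements, lanes 0-15 hold rows 0-7, lanes 16-31 hold rows 8-15, and the lane index modulo 16 selects the column. The following host-side sketch (not part of the patch; `lane` stands in for `threadIdx.x`) checks that this mapping touches each of the 256 tile elements exactly once:

```cpp
#include <cassert>

// Emulate get_i/get_j of the RDNA4 tile<16, 16, T> for every lane of a wave32.
int main() {
    int hits[16][16] = {};
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 8; ++l) {            // ne = 16*16/32 = 8 elements per lane
            const int i = 8 * (lane / 16) + l;   // get_i(l)
            const int j = lane % 16;             // get_j(l)
            ++hits[i][j];
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            assert(hits[i][j] == 1);             // every element is owned by exactly one lane
        }
    }
    return 0;
}
```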
@@ -140,6 +165,29 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#endif // defined(RDNA4)
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};

@@ -166,12 +214,36 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(AMD_WMMA_AVAILABLE)
     };

     template <int I_, int J_>
     struct tile<I_, J_, nv_bfloat162> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#endif // defined(RDNA4)
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -198,6 +270,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(AMD_WMMA_AVAILABLE)
     };

     template <int I, int J>
@@ -231,6 +304,19 @@ namespace ggml_cuda_mma {
             const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
             xi[0] = xs[0];
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        // Special tile size to load <16, 8> as <16, 16> for half2 and __hip_bfloat162
+        if constexpr (I == 16 && J == 8 && (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value)) {
+            constexpr int RDNA4_WMMA_MEM_N = 4;
+            using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int;
+            reinterpret_cast<TxN_t &>(t.x[0]) = reinterpret_cast<const TxN_t &>(xs0[t.get_i(0) * stride + t.get_j(0)]);
+        } else {
+            constexpr int RDNA4_WMMA_MEM_N = 8;
+            using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T;
+            reinterpret_cast<TxN_t &>(t.x[0]) = reinterpret_cast<const TxN_t &>(xs0[t.get_i(0) * stride + t.get_j(0)]);
+        }
+#endif // defined(RDNA4)
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
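
The RDNA4 branch of this loader replaces the scalar per-element loop with one wide vector access per lane: for the <16, 8> half2/bfloat162 operand tiles, four packed `int`s (eight half-precision values) are read at `(get_i(0), get_j(0))` in a single move, and the generic case moves all `t.ne` elements of type `T` the same way. Below is a minimal host-side sketch of that `ext_vector_type` reinterpret-copy idiom, written in the same form the patch uses; the names `int4v` and `copy4` are invented for the sketch, and the buffers are aligned explicitly (in the kernel, alignment follows from the tile layout). It builds with clang or hipcc:

```cpp
#include <cassert>

// Clang/hipcc vector extension, same form as the TxN_t alias in the patch.
using int4v = __attribute__((ext_vector_type(4))) int;

// One 16-byte move instead of four scalar copies.
static void copy4(int * dst, const int * src) {
    reinterpret_cast<int4v &>(dst[0]) = reinterpret_cast<const int4v &>(src[0]);
}

int main() {
    alignas(16) int src[4] = {1, 2, 3, 4};
    alignas(16) int dst[4] = {0};
    copy4(dst, src);
    for (int l = 0; l < 4; ++l) {
        assert(dst[l] == src[l]);
    }
    return 0;
}
```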
@@ -461,6 +547,25 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }

+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, float> & A, const tile<16, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
558+ : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[0 ]), " r" (Bxi[1 ]));
559+ asm (" mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, "
560+ " %2, %3};"
561+ : " +r" (Dxi[4 ]), " +r" (Dxi[5 ]), " +r" (Dxi[6 ]), " +r" (Dxi[7 ])
562+ : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[1 ]), " r" (Bxi[3 ]));
563+ #else
564+ GGML_UNUSED_VARS (D, A, B);
565+ NO_DEVICE_CODE;
566+ #endif // AMPERE_MMA_AVAILABLE
567+ }
568+
464569 static __device__ __forceinline__ void mma (
465570 tile<16 , 16 , float > & D, const tile<16 , 8 , half2> & A, const tile<16 , 8 , half2> & B) {
466571#ifdef TURING_MMA_AVAILABLE
@@ -489,12 +594,48 @@ namespace ggml_cuda_mma {
489594 : " +r" (Dxi[4 ]), " +r" (Dxi[5 ]), " +r" (Dxi[6 ]), " +r" (Dxi[7 ])
490595 : " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[3 ]));
491596#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
597+ #elif defined(AMD_WMMA_AVAILABLE)
598+ #if defined(RDNA4)
599+ using halfx8_t = __attribute__ ((ext_vector_type (8 ))) _Float16;
600+ using floatx8_t = __attribute__ ((ext_vector_type (8 ))) float ;
601+ floatx8_t & acc_frag = reinterpret_cast <floatx8_t &>(D.x [0 ]);
602+ const halfx8_t & a_frag = reinterpret_cast <const halfx8_t &>(A.x [0 ]);
603+ const halfx8_t & b_frag = reinterpret_cast <const halfx8_t &>(B.x [0 ]);
604+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12 (a_frag, b_frag, acc_frag);
605+ #endif // defined(RDNA4)
492606#else
493607 GGML_UNUSED_VARS (D, A, B);
494608 NO_DEVICE_CODE;
495609#endif // TURING_MMA_AVAILABLE
496610 }
497611
612+ static __device__ __forceinline__ void mma (
613+ tile<16 , 16 , float > & D, const tile<16 , 8 , nv_bfloat162> & A, const tile<16 , 8 , nv_bfloat162> & B) {
614+ #ifdef AMPERE_MMA_AVAILABLE
615+ const int * Axi = (const int *) A.x ;
616+ const int * Bxi = (const int *) B.x ;
617+ int * Dxi = (int *) D.x ;
618+ asm (" mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
619+ : " +r" (Dxi[0 ]), " +r" (Dxi[1 ]), " +r" (Dxi[2 ]), " +r" (Dxi[3 ])
620+ : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[0 ]), " r" (Bxi[2 ]));
621+ asm (" mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
622+ : " +r" (Dxi[4 ]), " +r" (Dxi[5 ]), " +r" (Dxi[6 ]), " +r" (Dxi[7 ])
623+ : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[1 ]), " r" (Bxi[3 ]));
624+ #elif defined(AMD_WMMA_AVAILABLE)
625+ #if defined(RDNA4)
626+ using bf16x8_t = __attribute__ ((ext_vector_type (8 ))) __bf16;
627+ using floatx8_t = __attribute__ ((ext_vector_type (8 ))) float ;
628+ floatx8_t & acc_frag = reinterpret_cast <floatx8_t &>(D.x [0 ]);
629+ const bf16x8_t & a_frag = reinterpret_cast <const bf16x8_t &>(A.x [0 ]);
630+ const bf16x8_t & b_frag = reinterpret_cast <const bf16x8_t &>(B.x [0 ]);
631+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12 (a_frag, b_frag, acc_frag);
632+ #endif // defined(RDNA4)
633+ #else
634+ GGML_UNUSED_VARS (D, A, B);
635+ NO_DEVICE_CODE;
636+ #endif // AMPERE_MMA_AVAILABLE
637+ }
638+
498639 static __device__ __forceinline__ void mma (
499640 tile<16 , 16 , int > & D, const tile<16 , 8 , int > & A, const tile<16 , 8 , int > & B) {
500641#if defined(AMD_MFMA_AVAILABLE)
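
Putting the pieces together, the tiles, the loader, and the new `mma` overloads are typically combined along the following lines. This is a hedged sketch rather than code from the patch: it assumes the patched header is `mma.cuh`, that the loader shown above is `ggml_cuda_mma::load_generic`, and that `C` is a dense, 16-wide row-major output block; it would be launched with one warp (NVIDIA) or one wave32 (RDNA4) per tile.

```cpp
#include "mma.cuh"   // assumed name of the header being patched

using namespace ggml_cuda_mma;

// Compute one 16x16 f32 output tile from 16x16 half-precision operands A and B
// (B in n-major form, per the row.col mma convention). On RDNA4 the mma() call
// lowers to the WMMA builtin added above; on NVIDIA it lowers to mma.sync PTX.
__global__ void tile_mma_example(const half2 * A, const half2 * B, float * C, const int stride) {
    tile<16, 16, float> acc;      // accumulator, zero-initialized by the tile definition
    tile<16,  8, half2> a, b;     // 16x16 halves stored as a 16x8 tile of half2

    load_generic(a, A, stride);   // one vector load per lane on RDNA4, scalar loop elsewhere
    load_generic(b, B, stride);
    mma(acc, a, b);

#pragma unroll
    for (int l = 0; l < acc.ne; ++l) {
        C[acc.get_i(l) * 16 + acc.get_j(l)] = acc.x[l];
    }
}
```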