@@ -127,7 +127,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
-#endif
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     };
 
     template <int I_, int J_>
@@ -182,10 +182,16 @@ namespace ggml_cuda_mma {
 
     template <int I, int J, typename T>
     static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+        int64_t * xi = (int64_t *) t.x;
+        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+        xi[0] = xs[0];
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
+#endif // defined(AMD_MFMA_AVAILABLE)
     }
 
     template <typename T>
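The added AMD_MFMA_AVAILABLE branch replaces the per-element loop with one 8-byte copy per lane: for an int tile such as tile<16, 8, int>, each lane of a 64-wide CDNA wavefront reads row threadIdx.x % t.I and the pair of ints starting at column 2 * (threadIdx.x / t.I). Below is a minimal host-side sketch of that lane-to-element mapping, not part of the patch; the helper names (lane_row, lane_col) and the printout are purely illustrative.

// Sketch only: reproduces the lane -> (row, int-column) arithmetic used by the
// AMD_MFMA_AVAILABLE branch of load_generic() for a tile<16, 8, int>.
#include <cstdio>

int main() {
    const int I = 16, J = 8;                   // tile shape in ints
    for (int lane = 0; lane < 64; ++lane) {    // one CDNA wavefront (wave64)
        const int lane_row = lane % I;         // (threadIdx.x % t.I)
        const int lane_col = 2 * (lane / I);   // 2 * (threadIdx.x / t.I), in ints
        // Each lane copies ints [lane_col, lane_col + 1] of row lane_row,
        // i.e. one int64_t per lane; 64 lanes * 8 bytes = 512 bytes = 16 x 8 ints.
        printf("lane %2d -> row %2d, ints %d..%d\n", lane, lane_row, lane_col, lane_col + 1);
    }
    return 0;
}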
@@ -220,11 +226,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(AMD_MMA_AVAILABLE)
-        int64_t * xi = (int64_t *) t.x;
-        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-        xi[0] = xs[0];
-#elif defined(NEW_MMA_AVAILABLE)
+#if defined(NEW_MMA_AVAILABLE)
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -235,23 +237,6 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
 
-    template <typename T>
-    static __device__ __forceinline__ void load_ldmatrix(
-            tile<32, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(AMD_MMA_AVAILABLE)
-        int64_t * xi = (int64_t *) t.x;
-        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-        xi[0] = xs[0];
-#elif defined(NEW_MMA_AVAILABLE)
-        GGML_UNUSED(t);
-        GGML_UNUSED(xs0);
-        GGML_UNUSED(stride);
-        NO_DEVICE_CODE;
-#else
-        load_generic(t, xs0, stride);
-#endif // AMD_MMA_AVAILABLE
-    }
-
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -451,15 +436,23 @@ namespace ggml_cuda_mma {
 
     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
-#if defined(AMD_MMA_AVAILABLE)
-#if defined(CDNA3)
+#if defined(AMD_MFMA_AVAILABLE)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
-        int32x4_t * acc = (int32x4_t *) D.x;
-        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
-                                                       ((int64_t *) B.x)[0],
+        int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
                                                        acc[0],
                                                        0, 0, 0);
 #elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+                                                      B.x[0],
+                                                      acc[0],
+                                                      0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+                                                      B.x[1],
+                                                      acc[0],
+                                                      0, 0, 0);
 #endif
 #else
         GGML_UNUSED(D);
@@ -471,15 +464,23 @@ namespace ggml_cuda_mma {
 
     static __device__ __forceinline__ void mma(
             tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
-#if defined(AMD_MMA_AVAILABLE)
-#if defined(CDNA3)
+#if defined(AMD_MFMA_AVAILABLE)
         using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
-        int32x16_t * acc = (int32x16_t *) D.x;
-        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
-                                                       ((int64_t *) B.x)[0],
+        int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
                                                        acc[0],
                                                        0, 0, 0);
 #elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+                                                     B.x[0],
+                                                     acc[0],
+                                                     0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+                                                     B.x[1],
+                                                     acc[0],
+                                                     0, 0, 0);
 #endif
 #else
         GGML_UNUSED(D);
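As the two hunks above show, the CDNA3 builtins take eight packed int8 values per lane (an int64_t) and cover the full depth in one instruction (K = 32 for the 16x16 tile, K = 16 for the 32x32 tile), while the older CDNA/CDNA2 builtins take only four packed int8 per lane (an int), so the fallback issues the instruction twice and chains the accumulator. A rough host-side sketch of that equivalence follows; it uses plain scalar dot products instead of the real per-lane MFMA layout, and all names in it (dot_i8, the test data) are illustrative, not from the patch.

// Sketch: two half-depth accumulating steps equal one full-depth step,
// which is why the CDNA/CDNA2 path calls the K=16 (or K=8) builtin twice.
#include <cstdint>
#include <cstdio>

static int32_t dot_i8(const int8_t * a, const int8_t * b, int k, int32_t acc) {
    for (int i = 0; i < k; ++i) {
        acc += int32_t(a[i]) * int32_t(b[i]);   // int8 x int8 products, int32 accumulation
    }
    return acc;
}

int main() {
    int8_t a[32], b[32];
    for (int i = 0; i < 32; ++i) { a[i] = int8_t(i - 16); b[i] = int8_t(3 * i % 7 - 3); }

    const int32_t one_step  = dot_i8(a, b, 32, 0);            // single K=32 instruction (CDNA3)
    const int32_t two_steps = dot_i8(a + 16, b + 16, 16,      // second K=16 step ...
                                     dot_i8(a, b, 16, 0));    // ... chained onto the first (CDNA/CDNA2)
    printf("K=32 in one step: %d, in two K=16 steps: %d\n", one_step, two_steps);
    return 0;
}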