@@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
         static constexpr int I = I_;
         static constexpr int J = J_;
 
-#if defined(GGML_USE_HIP)
-#if defined(RDNA4)
-        static constexpr int ne = I * J / 32;
-        T x[ne] = {0};
-
-        static constexpr __device__ bool supported() {
-            if (I == 16 && J == 16) return true;
-            return false;
-        }
-
-        static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return 8 * (threadIdx.x / 16) + l;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
-        }
-
-        static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
-        }
-#else
+#if defined(AMD_MFMA_AVAILABLE)
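+        // MFMA (CDNA) path: wave64, so each lane of the 64-wide wavefront owns I*J/64 tile elements.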
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
@@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
-#endif // defined(RDNA4)
 #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
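+        // RDNA4 WMMA runs in wave32: a 16x16 tile is spread over 32 lanes, 8 elements per lane.
+        // Lanes 0-15 hold rows 0-7 of column (lane % 16); lanes 16-31 hold rows 8-15.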
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#endif // defined(RDNA4)
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -437,7 +437,20 @@ namespace ggml_cuda_mma {
             xi[0] = xs[0];
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
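+        // Per-lane vectorized loads: each lane reads its row (lane % 16) as 64-bit chunks,
+        // with the upper 16 lanes offset into the higher columns of the tile.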
+        if constexpr (I == 16 && J == 4) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        } else if constexpr (I == 16 && J == 8) {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+
+            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+            xi[1] = xs1[0];
+        } else {
+            NO_DEVICE_CODE;
+        }
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
@@ -772,6 +785,36 @@ namespace ggml_cuda_mma {
             acc[0],
             0, 0, 0);
 #endif // defined(CDNA3)
+
+#elif defined(AMD_WMMA_AVAILABLE)
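+        // Reinterpret the per-lane tile storage as the Clang extended vector types
+        // that the AMD WMMA builtins take and return.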
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+
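+        // Two chained 16x16x16 WMMA steps accumulate the full K=32 tile in place.
+        // The leading bools mark A and B as signed int8; the trailing bool clamps
+        // (saturates) the int32 accumulator.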
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            true
+        );
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[1],
+            true,
+            b_vec[1],
+            acc[0],
+            true
+        );
+#endif // defined(RDNA4)
+
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -798,6 +841,7 @@ namespace ggml_cuda_mma {
             acc[0],
             0, 0, 0);
 #endif // defined(CDNA3)
+
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -842,4 +886,31 @@ namespace ggml_cuda_mma {
         mma(D16[1], A16[1], B);
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
+
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
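+        // Single K=16 WMMA step: signed int8 operands, no saturation on the
+        // int32 accumulator (the final bool is the clamp flag).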
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+            true,
+            a_vec[0],
+            true,
+            b_vec[0],
+            acc[0],
+            false
+        );
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    }
 }
+