@@ -70,7 +70,30 @@ namespace ggml_cuda_mma {
7070 static constexpr int J = J_;
7171
7272#if defined(GGML_USE_HIP)
73- #if defined(CDNA)
73+ #if defined(RDNA4)
74+ static constexpr int ne = I * J / 32 ;
75+ T x[ne] = {0 };
76+
77+ static __device__ __forceinline__ int get_i (const int l) {
78+ if constexpr (I == 16 && J == 16 ) {
79+ return 8 * (threadIdx .x / 16 ) + l;
80+ } else if constexpr (I == 16 && J == 8 ) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
81+ return 4 * (threadIdx .x / 16 ) + l;
82+ } else {
83+ static_assert (I == -1 && J == -1 , " template specialization not implemented" );
84+ }
85+ }
86+
87+ static __device__ __forceinline__ int get_j (const int l) {
88+ if constexpr (I == 16 && J == 16 ) {
89+ return threadIdx .x % 16 ;
90+ } else if constexpr (I == 16 && J == 8 ) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
91+ return threadIdx .x % 16 ;
92+ } else {
93+ static_assert (I == -1 && J == -1 , " template specialization not implemented" );
94+ }
95+ }
96+ #else
7497 static constexpr int ne = I * J / 64 ;
7598 T x[ne] = {0 };
7699
@@ -105,30 +128,7 @@ namespace ggml_cuda_mma {
105128 static_assert (I == -1 && J == -1 , " template specialization not implemented" );
106129 }
107130 }
108- #elif defined(RDNA4)
109- static constexpr int ne = I * J / 32 ;
110- T x[ne] = {0 };
111-
112- static __device__ __forceinline__ int get_i (const int l) {
113- if constexpr (I == 16 && J == 16 ) {
114- return 8 * (threadIdx .x / 16 ) + l;
115- } else if constexpr (I == 16 && J == 8 ) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
116- return 4 * (threadIdx .x / 16 ) + l;
117- } else {
118- static_assert (I == -1 && J == -1 , " template specialization not implemented" );
119- }
120- }
121-
122- static __device__ __forceinline__ int get_j (const int l) {
123- if constexpr (I == 16 && J == 16 ) {
124- return threadIdx .x % 16 ;
125- } else if constexpr (I == 16 && J == 8 ) { // dummy shape to handle TF32, just make the compiler happle, don't use it in the real case
126- return threadIdx .x % 16 ;
127- } else {
128- static_assert (I == -1 && J == -1 , " template specialization not implemented" );
129- }
130- }
131- #endif // defined(CDNA)
131+ #endif // defined(RDNA4)
132132#else
133133 static constexpr int ne = I * J / 32 ;
134134 T x[ne] = {0 };
0 commit comments