 
 #include "common.cuh"
 
+// On Volta each warp performs 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in the I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM

 #if CUDART_VERSION >= 11080

@@ -86,6 +90,7 @@ namespace ggml_cuda_mma {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

@@ -102,6 +107,32 @@ namespace ggml_cuda_mma {
                 return threadIdx.x % 32;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+#else
+                return (l & 2) | (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) | (l & (4 + 1));
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
 #else
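The permuted and unpermuted Volta index mappings added above are pure bit shuffles of the lane id and the element index l (ne = 32*8/32 = 8 per thread). Below is a minimal host-side sketch, not part of the patch, that copies the formulas from this hunk and checks that both variants assign every element of the 32x8 tile to exactly one (lane, l) pair; the helper names are made up for the illustration.

#include <cassert>
#include <cstdio>

// get_i/get_j formulas copied from the Volta tile<32, 8, T> specialization above.
static int get_i_perm(int lane, int l)    { return (l & 2) | (lane & ~2); }
static int get_i_no_perm(int lane, int l) {
    return (((lane % 16) / 4) * 8) | ((lane / 16) * 4) | (l & 2) | (lane % 2);
}
static int get_j(int lane, int l)         { return (lane & 2) | (l & (4 + 1)); }

int main() {
    for (int perm = 0; perm < 2; ++perm) {
        int hits[32][8] = {};
        for (int lane = 0; lane < 32; ++lane) {
            for (int l = 0; l < 8; ++l) { // ne == 8 elements per thread
                const int i = perm ? get_i_perm(lane, l) : get_i_no_perm(lane, l);
                ++hits[i][get_j(lane, l)];
            }
        }
        for (int i = 0; i < 32; ++i) {
            for (int j = 0; j < 8; ++j) {
                assert(hits[i][j] == 1); // every tile element is owned exactly once
            }
        }
    }
    printf("both Volta 32x8 index mappings cover the tile exactly once\n");
    return 0;
}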
@@ -111,26 +142,28 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && (J == 4 || J == 8)) {
                 return threadIdx.x / 4;
-            } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+                return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return 4 * l + threadIdx.x % 4;
-            } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x % 4) + l % 2;
+                return (l * 4) | (threadIdx.x % 4);
+            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+                return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 #endif // defined(GGML_USE_HIP)
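The pattern in this hunk (and the analogous ones below) replaces additive index arithmetic such as (l / 2) * 8 + threadIdx.x / 4 with a bitwise OR of the same terms. That is only an identity because the OR'd terms occupy disjoint bit ranges: threadIdx.x / 4 and threadIdx.x % 4 stay below 8, and the scaled terms only set higher bits, so no carries are lost. A tiny host-side check over the relevant ranges, not part of the patch:

#include <cassert>

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < 8; ++l) {
            // get_i, (I == 16 || I == 32) && J == 8: bits 3..4 from l, bits 0..2 from the lane
            assert((l / 2) * 8 + lane / 4 == (((l / 2) * 8) | (lane / 4)));
            // get_j, same case: bits 1..2 from the lane, bit 0 from l
            assert(2 * (lane % 4) + l % 2 == (((lane % 4) * 2) | (l % 2)));
            // get_j, I == 16 && J == 16: three disjoint bit fields
            assert(8 * (l / 4) + 2 * (lane % 4) + l % 2
                   == (((l / 4) * 8) | ((lane % 4) * 2) | (l % 2)));
        }
    }
    return 0;
}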
@@ -140,32 +173,68 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr ((I == 8 || I == 32) && J == 8) {
+                return l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};

         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };

     template <int I_, int J_>
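A quick arithmetic note on the Volta-specific ne above, not part of the patch: for the 8x8 half2 tile, ne = 8*8 / (WARP_SIZE/4) = 8 half2 per thread, i.e. 32 * 8 = 256 half2 across the warp for a tile that only contains 64 elements, so the warp holds the 8x8 B operand four times over, presumably once per group of 8 threads, which is consistent with the "mirror the B tile" comment at the top of the file. For the 32x8 half2 tile, ne = 32*8 / WARP_SIZE = 8 half2 per thread with no replication (256 elements in 256 register slots).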
@@ -179,23 +248,25 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
     };
@@ -263,8 +334,12 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
-        GGML_UNUSED(t);
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }

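Context for the Volta branches added here and in the next hunk: the ldmatrix PTX instruction only exists on Turing (sm_75) and newer, and the Volta tile specializations added in this patch only define layouts for the 32x8 and 8x8 shapes, so on Volta these ldmatrix wrappers compile to NO_DEVICE_CODE instead of falling back to load_generic.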
@@ -277,11 +352,35 @@ namespace ggml_cuda_mma {
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #else
         load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }

+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+        load_generic(t, xs0, stride);
+#endif // 1
+#else
+        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
+        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
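The new 32x8 loader above takes two routes: on architectures with ldmatrix it reinterprets the 32x8 tile as two 16x8 tiles stacked at row offsets 0 and 16 and reuses the 16x8 loader for each half, while on Volta each thread issues two 16-byte ggml_cuda_memcpy_1 copies (columns 0-3 and 4-7 of the rows given by get_i(0) and get_i(4)), with the static_assert restricting this path to 4-byte element types for now, as the TODO notes.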
@@ -546,4 +645,43 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
+#else
+        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
 }
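A reading note on the Volta path above, not part of the patch: A is the 32x8 half2 tile (K = 16 half values, Axi[0..7]), B is the 8x8 half2 tile (Bxi[0..7]) and D is the 32x8 float tile (Dxi[0..7]). Each mma.sync.aligned.m8n8k4 issue feeds two half2 registers from A and B (4 half values of K) and accumulates into all eight D registers, so the four issues together cover K = 4 * 4 = 16. The "4 8x8 mma operations in parallel" from the header comment are the four quad pairs of the warp, whose 8x8 outputs stack in the I direction to form the 32x8 D tile, which is what the non-Volta branch reproduces by splitting D and A into two 16x8 halves.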