
 #include "common.cuh"

+// On Volta each warp performs 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in the I direction and to mirror the B tile.
+// However, the i indices in this file are permuted by default to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM
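+// With the default permutation, get_i() returns a swizzled row index rather than the true row;
+// define GGML_CUDA_MMA_NO_VOLTA_PERM to make get_i() return the plain (unpermuted) row index instead.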

 #if CUDART_VERSION >= 11080

@@ -73,6 +77,15 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};

+        static constexpr __device__ bool supported() {
+            if (I == 64 && J == 2) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 32 && J == 4) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 32) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                 return threadIdx.x % 16;
@@ -85,7 +98,8 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }

@@ -101,36 +115,84 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return threadIdx.x % 32;
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+#else
+                return (l & 2) | (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) | (l & (4 + 1));
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};

+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && (J == 4 || J == 8)) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+                return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return 4 * l + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x % 4) + l % 2;
+                return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 #endif // defined(GGML_USE_HIP)
@@ -140,32 +202,83 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
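+        // For the 8x8 tile ne is 4x larger: on Volta each quad pair holds a full copy of the 8x8 tile,
+        // i.e. it is replicated 4 times across the warp.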
+        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr ((I == 8 || I == 32) && J == 8) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};

+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };

     template <int I_, int J_>
@@ -175,27 +288,36 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }

         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
     };
@@ -263,8 +385,12 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
-        GGML_UNUSED(t);
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }

@@ -277,11 +403,35 @@ namespace ggml_cuda_mma {
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #else
         load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }

+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
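+        // Each thread loads its 8 elements as two 16-byte copies:
+        // x[0..3] from row get_i(0), columns 0..3, and x[4..7] from row get_i(4), columns 4..7.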
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+        load_generic(t, xs0, stride);
+#endif // 1
+#else
+        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
+        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -546,4 +696,43 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
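+        // Split the 32-row tiles into two 16-row halves and apply the 16xJ mma to each half.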
+        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
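+        // Volta mma.sync.m8n8k4 works per quad pair; the four calls below step through the K dimension
+        // (two half2 registers of A and B each) and all accumulate into the same 8 output registers per thread.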
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
+#else
+        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
 }
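For context, here is a minimal sketch of how the new 32x8 Volta path is meant to compose with the existing helpers. It is illustrative only and not part of this commit: the kernel name warp_gemm_32x8, the K/lda/ldb parameters, and the plain row-major writeback are assumptions made for the example.

// Illustrative sketch only (not from this commit): one warp accumulates a 32x8 output tile,
// assuming the tile/load_ldmatrix/load_generic/mma definitions from mma.cuh above.
template <int K>
static __device__ void warp_gemm_32x8(
        float * __restrict__ C, const half2 * __restrict__ A, const half2 * __restrict__ B,
        const int lda, const int ldb) {
    using namespace ggml_cuda_mma;

    tile<32, 8, float> acc; // per-thread fragment of the 32x8 accumulator, zero-initialized

#pragma unroll
    for (int k = 0; k < K; k += 8) { // 8 half2 = 16 halves of K per step
        tile<32, 8, half2> a;
        tile< 8, 8, half2> b;
        load_ldmatrix(a, A + k, lda); // on Volta: two 16-byte copies per thread
        load_generic (b, B + k, ldb); // generic per-element load via get_i()/get_j()
        mma(acc, a, b);               // on Volta: 4x mma.sync.aligned.m8n8k4
    }

    // Write back through the tile's own index mapping. Note that on Volta get_i() is
    // permuted by default (see the comment at the top of the file), so a real kernel
    // either defines GGML_CUDA_MMA_NO_VOLTA_PERM or undoes the permutation here.
#pragma unroll
    for (int l = 0; l < acc.ne; ++l) {
        C[acc.get_i(l)*8 + acc.get_j(l)] = acc.x[l];
    }
}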