
#include "common.cuh"

+ // On Volta each warp performs 4 8x8 mma operations in parallel.
+ // The basic memory layout for a 32x8 output tile is to stack 4 input tiles in the I direction and to mirror the B tile.
+ // However, the i indices in this file are permuted by default to simplify the index calculations.
+ // #define GGML_CUDA_MMA_NO_VOLTA_PERM
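+ // As an illustration (not used by the kernels): with the default permutation the 32x8 half2 A tile below
+ // maps lane k to row i == k, whereas with GGML_CUDA_MMA_NO_VOLTA_PERM defined the same lane maps to
+ // i == (((k % 16) / 4) * 8) | ((k / 16) * 4) | (k % 4), the unpermuted stacked layout described above.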

#if CUDART_VERSION >= 11080

@@ -73,6 +77,15 @@ namespace ggml_cuda_mma {
        static constexpr int ne = I * J / 64;
        T x[ne] = {0};

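+         // Returns true if this combination of I and J has a data layout defined for the current architecture.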
+         static constexpr __device__ bool supported() {
+             if (I == 64 && J ==  2) return true;
+             if (I == 16 && J ==  8) return true;
+             if (I == 32 && J ==  4) return true;
+             if (I == 16 && J == 16) return true;
+             if (I == 32 && J == 32) return true;
+             return false;
+         }
+
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return threadIdx.x % 16;
@@ -85,7 +98,8 @@ namespace ggml_cuda_mma {
            } else if constexpr (I == 32 && J == 32) {
                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }

@@ -101,36 +115,84 @@ namespace ggml_cuda_mma {
            } else if constexpr (I == 32 && J == 32) {
                return threadIdx.x % 32;
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
+             }
+         }
+ #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+         static constexpr int ne = I * J / 32;
+         T x[ne] = {0};
+
+         static constexpr __device__ bool supported() {
+             if (I == 32 && J ==  8) return true;
+             return false;
+         }
+
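+         // Example of the default (permuted) mapping: for lane 7 and l == 6 the element sits at
+         // i == (6 & 2) | (7 & ~2) == 7 and j == (7 & 2) | (6 & 5) == 6, i.e. l supplies bit 1 of i
+         // plus bits 0 and 2 of j, and the lane index supplies the remaining bits.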
+         static __device__ __forceinline__ int get_i(const int l) {
+             if constexpr (I == 32 && J == 8) {
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                 return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+ #else
+                 return (l & 2) | (threadIdx.x & ~2);
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+             } else {
+                 NO_DEVICE_CODE;
+                 return -1;
+             }
+         }
+
+         static __device__ __forceinline__ int get_j(const int l) {
+             if constexpr (I == 32 && J == 8) {
+                 return (threadIdx.x & 2) | (l & (4 + 1));
+             } else {
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

+         static constexpr __device__ bool supported() {
+             if (I ==  8 && J ==  4) return true;
+             if (I ==  8 && J ==  8) return true;
+             if (I == 16 && J ==  8) return true;
+             if (I == 16 && J == 16) return true;
+             if (I == 32 && J ==  8) return true;
+             return false;
+         }
+
        static __device__ __forceinline__ int get_i(const int l) {
-             if constexpr (I == 8 && (J == 4 || J == 8)) {
+             if constexpr (I == 8 && J == 4) {
+                 return threadIdx.x / 4;
+             } else if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 8) {
-                 return (l / 2) * 8 + threadIdx.x / 4;
+                 return ((l / 2) * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 16) {
-                 return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                 return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+             } else if constexpr (I == 32 && J == 8) {
+                 return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 8 && J == 8) {
-                 return 4 * l + threadIdx.x % 4;
+                 return (l * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 8) {
-                 return 2 * (threadIdx.x % 4) + l % 2;
+                 return ((threadIdx.x % 4) * 2) | (l % 2);
            } else if constexpr (I == 16 && J == 16) {
-                 return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                 return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+             } else if constexpr (I == 32 && J == 8) {
+                 return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }
#endif // defined(GGML_USE_HIP)
@@ -140,32 +202,83 @@ namespace ggml_cuda_mma {
    struct tile<I_, J_, half2> {
        static constexpr int I = I_;
        static constexpr int J = J_;
+
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
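+         // The 8x8 half2 tile keeps a full row of 8 half2 values in every thread (ne == 8 instead of 2),
+         // matching the mirrored B layout described in the comment at the top of this file.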
+         static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+         half2 x[ne] = {{0.0f, 0.0f}};
+
+         static constexpr __device__ bool supported() {
+             if (I ==  8 && J ==  8) return true;
+             if (I == 32 && J ==  8) return true;
+             return false;
+         }
+
+         static __device__ __forceinline__ int get_i(const int l) {
+             if constexpr (I == 8 && J == 8) {
+                 return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+             } else if constexpr (I == 32 && J == 8) {
+ #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                 return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+ #else
+                 return threadIdx.x;
+ #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+             } else {
+                 NO_DEVICE_CODE;
+                 return -1;
+             }
+         }
+
+         static __device__ __forceinline__ int get_j(const int l) {
+             if constexpr ((I == 8 || I == 32) && J == 8) {
+                 return l;
+             } else {
+                 NO_DEVICE_CODE;
+                 return -1;
+             }
+         }
+ #else
        static constexpr int ne = I * J / WARP_SIZE;
        half2 x[ne] = {{0.0f, 0.0f}};

+         static constexpr __device__ bool supported() {
+             if (I ==  8 && J ==  4) return true;
+             if (I ==  8 && J ==  8) return true;
+             if (I == 16 && J ==  8) return true;
+             if (I == 16 && J == 16) return true;
+             if (I == 32 && J ==  8) return true;
+             return false;
+         }
+
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 4) {
-                 return l * 8 + threadIdx.x / 4;
+                 return (l * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 8) {
-                 return (l % 2) * 8 + threadIdx.x / 4;
+                 return ((l % 2) * 8) | (threadIdx.x / 4);
+             } else if constexpr (I == 32 && J == 8) {
+                 return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 8) {
-                 return l * 4 + threadIdx.x % 4;
+                 return (l * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 16 && J == 8) {
-                 return (l / 2) * 4 + threadIdx.x % 4;
+                 return ((l / 2) * 4) | (threadIdx.x % 4);
+             } else if constexpr (I == 32 && J == 8) {
+                 return ((l & 2) * 2) | (threadIdx.x % 4);
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    };

    template <int I_, int J_>
@@ -175,27 +288,36 @@ namespace ggml_cuda_mma {
        static constexpr int ne = I * J / WARP_SIZE;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

+         static constexpr __device__ bool supported() {
+             if (I ==  8 && J ==  8) return true;
+             if (I == 16 && J ==  4) return true;
+             if (I == 16 && J ==  8) return true;
+             return false;
+         }
+
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 4) {
-                 return l * 8 + threadIdx.x / 4;
+                 return (l * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 8) {
-                 return (l % 2) * 8 + threadIdx.x / 4;
+                 return ((l % 2) * 8) | (threadIdx.x / 4);
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 8) {
-                 return l * 4 + threadIdx.x % 4;
+                 return (l * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 16 && J == 8) {
-                 return (l / 2) * 4 + threadIdx.x % 4;
+                 return ((l / 2) * 4) | (threadIdx.x % 4);
            } else {
-                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                 NO_DEVICE_CODE;
+                 return -1;
            }
        }
    };
@@ -263,8 +385,12 @@ namespace ggml_cuda_mma {
            : "=r" (xi[0]), "=r" (xi[1])
            : "l" (xs));
#else
-         load_generic(xs0, stride);
-         GGML_UNUSED(t);
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+         GGML_UNUSED_VARS(t, xs0, stride);
+         NO_DEVICE_CODE;
+ #else
+         load_generic(t, xs0, stride);
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }

@@ -277,11 +403,35 @@ namespace ggml_cuda_mma {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
            : "=r" (xi[0]), "=r" (xi[1]), "=r" (xi[2]), "=r" (xi[3])
            : "l" (xs));
+ #else
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+         GGML_UNUSED_VARS(t, xs0, stride);
+         NO_DEVICE_CODE;
#else
        load_generic(t, xs0, stride);
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }

+     template <typename T>
+     static __device__ __forceinline__ void load_ldmatrix(
+             tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ #if 1
+         // TODO: more generic handling
+         static_assert(sizeof(T) == 4, "bad type size");
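+         // Each thread copies its 8 tile values as two 16-byte chunks: elements 0-3 from row get_i(0)
+         // at column offset 0 and elements 4-7 from row get_i(4) at column offset 4.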
+         ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+         ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+ #else
+         load_generic(t, xs0, stride);
+ #endif // 1
+ #else
+         tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+         load_ldmatrix(t16[0], xs0 +  0*stride, stride);
+         load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+     }
+
    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix_trans(
            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -546,4 +696,43 @@ namespace ggml_cuda_mma {
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }
+
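+     // Generic fallback for 32-row tiles: treat them as two 16-row tiles stacked in the i direction
+     // and perform the mma in two halves.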
+     template <typename T1, typename T2, int J, int K>
+     static __device__ __forceinline__ void mma(
+             tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+         tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+         tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+         mma(D16[0], A16[0], B);
+         mma(D16[1], A16[1], B);
+     }
+
+     static __device__ __forceinline__ void mma(
+             tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+         const int * Axi = (const int *) A.x;
+         const int * Bxi = (const int *) B.x;
+         int       * Dxi = (int       *) D.x;
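+         // 4 m8n8k4 mma instructions, one per 8x8 sub-tile: each consumes two A and two B registers and
+         // accumulates into the same 8 D registers (see the layout comment at the top of this file).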
+         asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+             : "+r" (Dxi[0]), "+r" (Dxi[1]), "+r" (Dxi[2]), "+r" (Dxi[3]), "+r" (Dxi[4]), "+r" (Dxi[5]), "+r" (Dxi[6]), "+r" (Dxi[7])
+             : "r" (Axi[0]), "r" (Axi[1]), "r" (Bxi[0]), "r" (Bxi[1]));
+         asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+             : "+r" (Dxi[0]), "+r" (Dxi[1]), "+r" (Dxi[2]), "+r" (Dxi[3]), "+r" (Dxi[4]), "+r" (Dxi[5]), "+r" (Dxi[6]), "+r" (Dxi[7])
+             : "r" (Axi[2]), "r" (Axi[3]), "r" (Bxi[2]), "r" (Bxi[3]));
+         asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+             : "+r" (Dxi[0]), "+r" (Dxi[1]), "+r" (Dxi[2]), "+r" (Dxi[3]), "+r" (Dxi[4]), "+r" (Dxi[5]), "+r" (Dxi[6]), "+r" (Dxi[7])
+             : "r" (Axi[4]), "r" (Axi[5]), "r" (Bxi[4]), "r" (Bxi[5]));
+         asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+             : "+r" (Dxi[0]), "+r" (Dxi[1]), "+r" (Dxi[2]), "+r" (Dxi[3]), "+r" (Dxi[4]), "+r" (Dxi[5]), "+r" (Dxi[6]), "+r" (Dxi[7])
+             : "r" (Axi[6]), "r" (Axi[7]), "r" (Bxi[6]), "r" (Bxi[7]));
+ #else
+         tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+         tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+         mma(D16[0], A16[0], B);
+         mma(D16[1], A16[1], B);
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+     }
}