Commit 31c511a

CUDA: Volta tensor core support for MMF (ggml-org#16843)
* CUDA: Volta tensor core support for MMF

* more generic checks for hardware support

* Update ggml/src/ggml-cuda/mmf.cuh

Co-authored-by: Aman Gupta <[email protected]>

---------

Co-authored-by: Aman Gupta <[email protected]>
1 parent 6d39015 commit 31c511a

4 files changed: +254 −36 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 9 additions & 1 deletion

@@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
+// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#define VOLTA_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }
 
-// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool volta_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
+}
+
 static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
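
For illustration, a minimal sketch (not code from this commit) of how the two new symbols are intended to be paired: the host-side predicate picks the code path from the device's compute capability, while the device-side macro gates the same decision at compile time. The launcher and kernel names below are hypothetical; only volta_mma_available, turing_mma_available, VOLTA_MMA_AVAILABLE and TURING_MMA_AVAILABLE come from common.cuh.

// Hypothetical host-side dispatch, assuming common.cuh is included.
static void launch_f16_mma_example(int cc) {
    if (volta_mma_available(cc) || turing_mma_available(cc)) {
        // launch the tensor-core (MMA) kernel
    } else {
        // fall back to a non-MMA kernel
    }
}

// Hypothetical device-side gating: the macros mirror the host check per compiled architecture.
__global__ void f16_mma_example_kernel() {
#if defined(VOLTA_MMA_AVAILABLE)
    // Volta-specific mma.m8n8k4 path (see mma.cuh below)
#elif defined(TURING_MMA_AVAILABLE)
    // ldmatrix + Turing/Ampere mma path
#else
    // scalar fallback
#endif
}
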

ggml/src/ggml-cuda/mma.cuh

Lines changed: 213 additions & 24 deletions

@@ -18,6 +18,10 @@
 
 #include "common.cuh"
 
+// On Volta each warp is doing 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM
 
 #if CUDART_VERSION >= 11080
 
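
To make the "permuted i indices" remark above concrete, the following standalone host program (not part of the commit) prints, for each of the 32 lanes, the row index returned by the default get_i() of the Volta 32x8 half2 tile (simply threadIdx.x) next to the row index the GGML_CUDA_MMA_NO_VOLTA_PERM variant would return; the two columns are permutations of each other, which is what lets the default layout use the simpler formula.

// Standalone sketch; "tid" stands in for threadIdx.x, formulas copied from get_i() below.
#include <cstdio>

int main() {
    for (int tid = 0; tid < 32; ++tid) {
        const int i_default = tid;                                                    // default (permuted) layout
        const int i_no_perm = (((tid % 16) / 4) * 8) | ((tid / 16) * 4) | (tid % 4);  // GGML_CUDA_MMA_NO_VOLTA_PERM layout
        printf("lane %2d: i_default = %2d, i_no_perm = %2d\n", tid, i_default, i_no_perm);
    }
    return 0;
}
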
@@ -73,6 +77,15 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 64 && J == 2) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 32 && J == 4) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 32) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                 return threadIdx.x % 16;
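
The new supported() hook gives callers a way to reject tile shapes that have no layout for the architecture being compiled, instead of relying on the static_assert that used to live inside get_i()/get_j(). A minimal sketch of a caller (not from this commit; the kernel name is hypothetical, the rest is from this file and common.cuh):

template <int I, int J, typename T>
__global__ void example_mma_kernel() {
    if constexpr (!ggml_cuda_mma::tile<I, J, T>::supported()) {
        NO_DEVICE_CODE; // same fallback the accessors now use for unsupported shapes
        return;
    }
    // ... the actual tensor core path would go here ...
}
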
@@ -85,7 +98,8 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
@@ -101,36 +115,84 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return threadIdx.x % 32;
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+#else
+                return (l & 2) | (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) | (l & (4 + 1));
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && (J == 4 || J == 8)) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+                return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return 4 * l + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x % 4) + l % 2;
+                return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 #endif // defined(GGML_USE_HIP)
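
To see what the Volta branch above encodes, this standalone host program (not part of the commit) enumerates the (i, j) element each lane owns for every l of the 32x8 integer/float tile, using the default get_i()/get_j() formulas verbatim; each lane holds ne = 8 elements and the warp covers all 32x8 positions exactly once.

// Standalone sketch; "tid" stands in for threadIdx.x.
#include <cstdio>

int main() {
    for (int tid = 0; tid < 32; ++tid) {
        printf("lane %2d:", tid);
        for (int l = 0; l < 8; ++l) {
            const int i = (l & 2) | (tid & ~2);       // get_i(l), default (permuted) layout
            const int j = (tid & 2) | (l & (4 + 1));  // get_j(l)
            printf(" (%2d,%d)", i, j);
        }
        printf("\n");
    }
    return 0;
}
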
@@ -140,32 +202,83 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr ((I == 8 || I == 32) && J == 8) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };
 
     template <int I_, int J_>
@@ -175,27 +288,36 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
+                return -1;
             }
         }
     };
@@ -263,8 +385,12 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
-        GGML_UNUSED(t);
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
 
@@ -277,11 +403,35 @@ namespace ggml_cuda_mma {
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #else
         load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+        load_generic(t, xs0, stride);
+#endif // 1
+#else
+        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
+        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
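
For the default (permuted) Volta half2 layout, get_i(l) of the 32x8 tile is simply threadIdx.x, so the two 16-byte copies above fetch the two halves of one source row per lane: x[0..3] from columns 0..3 and x[4..7] from columns 4..7 of row threadIdx.x. A standalone host sketch (not part of the commit) of the element ranges each lane touches, assuming a hypothetical row stride of 8 elements:

#include <cstdio>

int main() {
    const int stride = 8; // row stride of the 32x8 source block, in elements of T
    for (int tid = 0; tid < 32; ++tid) {
        const int lo = tid * stride + 0; // first ggml_cuda_memcpy_1: elements [lo, lo + 3]
        const int hi = tid * stride + 4; // second ggml_cuda_memcpy_1: elements [hi, hi + 3]
        printf("lane %2d copies [%3d..%3d] and [%3d..%3d]\n", tid, lo, lo + 3, hi, hi + 3);
    }
    return 0;
}
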
@@ -546,4 +696,43 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
+#else
+        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
 }
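
Roughly how these pieces fit together inside a Volta kernel, as a hedged sketch rather than code from this commit: the 32x8 half2 A tile is filled with the new load_ldmatrix overload, the 8x8 half2 B tile is assumed to be filled by the caller (its ldmatrix load is compiled out on Volta above), and the new mma overload then issues the four mma.sync.aligned.m8n8k4 instructions. The function name, pointer and stride parameters are hypothetical, and the include path assumes the sketch sits next to mma.cuh.

#include "mma.cuh"

using namespace ggml_cuda_mma;

static __device__ __forceinline__ void example_volta_mma_step(
        tile<32, 8, float> & D, const half2 * A_sh, const int stride_A, const tile<8, 8, half2> & B) {
    tile<32, 8, half2> A;             // four 8x8 input tiles stacked in the I direction
    load_ldmatrix(A, A_sh, stride_A); // on Volta: two 16-byte ggml_cuda_memcpy_1 copies per lane
    mma(D, A, B);                     // on Volta: four mma.sync.aligned.m8n8k4; otherwise two 16x8 halves
}
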

ggml/src/ggml-cuda/mmf.cu

Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc);
         case GGML_TYPE_BF16:
             return ampere_mma_available(cc);
         default:
