Commit 5d14386

CUDA: Volta tensor core support for MMF
1 parent f549b00 commit 5d14386

4 files changed, +257 -25 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 9 additions & 1 deletion
@@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
+// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#define VOLTA_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }
 
-// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool volta_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
+}
+
 static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
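
The new pair follows the Turing/Ampere convention already used in this header: VOLTA_MMA_AVAILABLE is the compile-time guard for device code built exactly for Volta, and volta_mma_available(cc) is the matching host-side check used when selecting a kernel (see the mmf.cu change further down). A minimal sketch of the pattern, not part of the commit, with a hypothetical device function name and placeholder bodies:

// Hedged sketch: how the guard pair is typically consumed (my_f16_tile_op is hypothetical).
static __device__ void my_f16_tile_op() {
#if defined(VOLTA_MMA_AVAILABLE)
    // Compiled only when __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA: take the Volta tensor core path.
#elif defined(TURING_MMA_AVAILABLE)
    // Turing and newer: take the existing ldmatrix/mma path.
#else
    NO_DEVICE_CODE; // no tensor core path compiled for this architecture
#endif // defined(VOLTA_MMA_AVAILABLE)
}

// Host side, at dispatch time:
// const bool use_f16_mmf = volta_mma_available(cc) || turing_mma_available(cc);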

ggml/src/ggml-cuda/mma.cuh

Lines changed: 155 additions & 17 deletions
@@ -18,6 +18,10 @@
 
 #include "common.cuh"
 
+// On Volta each warp is doing 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM
 
 #if CUDART_VERSION >= 11080
 
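To make the permutation note above concrete, here is a small host-side sketch, not part of the commit, that evaluates the two row mappings of the Volta tile<32, 8, half2> specialization added later in this diff: with the default permutation a lane's i index is simply its lane id, while GGML_CUDA_MMA_NO_VOLTA_PERM would select the row in the physical layout of four stacked 8x8 sub-tiles.

#include <cstdio>

// Row-index formulas copied from the tile<32, 8, half2> specialization below.
static int volta_get_i_permuted(int lane) { return lane; }
static int volta_get_i_physical(int lane) {
    return (((lane % 16) / 4) * 8) | ((lane / 16) * 4) | (lane % 4);
}

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        printf("lane %2d: permuted i = %2d, physical i = %2d\n",
               lane, volta_get_i_permuted(lane), volta_get_i_physical(lane));
    }
    return 0;
}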
@@ -86,6 +90,7 @@ namespace ggml_cuda_mma {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 
@@ -102,6 +107,32 @@ namespace ggml_cuda_mma {
                 return threadIdx.x % 32;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+#else
+                return (l & 2) | (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 32 && J == 8) {
+                return (threadIdx.x & 2) | (l & (4 + 1));
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 #else
@@ -111,26 +142,28 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && (J == 4 || J == 8)) {
                 return threadIdx.x / 4;
-            } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+                return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return 4 * l + threadIdx.x % 4;
-            } else if constexpr (I == 16 && J == 8) {
-                return 2 * (threadIdx.x % 4) + l % 2;
+                return (l * 4) | (threadIdx.x % 4);
+            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+                return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 #endif // defined(GGML_USE_HIP)
@@ -140,32 +173,68 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+#else
+                return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr ((I == 8 || I == 32) && J == 8) {
+                return l;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     };
 
     template <int I_, int J_>
@@ -179,23 +248,25 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return l * 8 + threadIdx.x / 4;
+                return (l * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return (l % 2) * 8 + threadIdx.x / 4;
+                return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return l * 4 + threadIdx.x % 4;
+                return (l * 4) | (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return (l / 2) * 4 + threadIdx.x % 4;
+                return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
+                return -1;
             }
         }
     };
@@ -263,8 +334,12 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
-        GGML_UNUSED(t);
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
+#else
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
 
@@ -277,11 +352,35 @@ namespace ggml_cuda_mma {
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #else
         load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+        // TODO: more generic handling
+        static_assert(sizeof(T) == 4, "bad type size");
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+        load_generic(t, xs0, stride);
+#endif // 1
+#else
+        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
+        load_ldmatrix(t16[0], xs0 + 0*stride, stride);
+        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
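
Brief usage sketch for the new 32x8 overload, not from the commit: on Volta each thread copies two 16-byte chunks of its (permuted) row directly, while on Turing and newer the call decomposes into two 16x8 ldmatrix loads. The function and parameter names below are hypothetical.

// Hedged sketch; assumes #include "mma.cuh" and a warp-cooperative caller.
using namespace ggml_cuda_mma;

static __device__ void load_warp_tile_32x8(const half2 * __restrict__ A_sh, const int stride,
                                           tile<32, 8, half2> & A) {
    // Volta: two ggml_cuda_memcpy_1<16> copies per thread; Turing+: two 16x8 ldmatrix loads.
    load_ldmatrix(A, A_sh, stride);
}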
@@ -546,4 +645,43 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // AMD_MFMA_AVAILABLE
     }
+
+    template <typename T1, typename T2, int J, int K>
+    static __device__ __forceinline__ void mma(
+            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
+        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
+        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
+#else
+        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
+        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        mma(D16[0], A16[0], B);
+        mma(D16[1], A16[1], B);
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+    }
 }
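
As a rough end-to-end sketch of the new 32-row path (an illustration, not the MMF kernel itself; the helper name, buffers, and strides are assumptions), one warp can now accumulate a 32x8 float block from a 32x8 half2 tile and an 8x8 half2 tile with a single mma() call: on Volta this lowers to the four m8n8k4 HMMA instructions above, on Turing and newer it forwards to two 16x8 mma calls.

// Hedged sketch; assumes #include "mma.cuh".
using namespace ggml_cuda_mma;

static __device__ void warp_mma_32x8(const half2 * __restrict__ A_sh, const int stride_A,
                                     const half2 * __restrict__ B_sh, const int stride_B,
                                     float * __restrict__ C, const int stride_C) {
    tile<32, 8, half2> A;  // 32 rows of the first operand
    tile< 8, 8, half2> B;  //  8 rows of the second operand
    tile<32, 8, float> D;  // 32x8 accumulator, zero-initialized

    load_ldmatrix(A, A_sh, stride_A);  // new overload added above
    load_generic (B, B_sh, stride_B);  // generic per-element load, available on all archs
    mma(D, A, B);                      // Volta: 4x m8n8k4; Turing+: forwarded to the 16x8 overload

    // Write back through the tile's own index mapping; per the note at the top of this
    // file the default i permutation is assumed to be applied consistently on loads and stores.
#pragma unroll
    for (int l = 0; l < D.ne; ++l) {
        C[D.get_i(l)*stride_C + D.get_j(l)] += D.x[l];
    }
}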

ggml/src/ggml-cuda/mmf.cu

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc);
         case GGML_TYPE_BF16:
             return ampere_mma_available(cc);
         default:
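
The practical effect of this one-line change: F16 matrix multiplication can now be routed through MMF on devices whose best compiled architecture is exactly Volta, while Turing and newer keep the path they already used. A host-side illustration, not part of the commit, assuming GGML_CUDA_CC_VOLTA and GGML_CUDA_CC_TURING correspond to compute capabilities 7.0 and 7.5 and that both architectures are compiled in:

// For illustration only.
static const char * f16_mmf_path(const int cc) {
    if (volta_mma_available(cc)) {
        return "MMF via Volta m8n8k4 tensor cores (newly enabled)";
    }
    if (turing_mma_available(cc)) {
        return "MMF via the existing Turing+ path (unchanged)";
    }
    return "no MMF path for F16";
}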
