Commit 7efb6ac

more generic checks for hardware support
1 parent 5d14386 commit 7efb6ac

2 files changed: +100 -114 lines
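
The substantive change: each tile<I, J, T> specialization in mma.cuh gains a constexpr supported() predicate, and the unimplemented-shape fallbacks trade a build-breaking static_assert for a NO_DEVICE_CODE trap. Kernels can then select a tile shape generically at compile time instead of hard-coding per-architecture #ifdef ladders. A minimal sketch of the resulting idiom (example_kernel is hypothetical; tile and NO_DEVICE_CODE are the real ggml symbols, and the kernel body is elided):

    // Sketch only: probe which A/C tile heights this architecture implements.
    template <typename T>
    static __global__ void example_kernel(/* ... */) {
        constexpr bool I_16 = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
        constexpr bool I_32 = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

        if (!I_16 && !I_32) {
            NO_DEVICE_CODE; // no usable tile shape here: compile a trap, not a build error
            return;
        }

        constexpr int I = I_16 ? 16 : 32; // prefer 16 where both work

        typedef tile<I, 8, T>     tile_A;
        typedef tile<I, 8, float> tile_C;
        // ... MMA loop over tile_A/tile_C fragments ...
    }

This is exactly the shape of the new mul_mat_f and mul_mat_f_ids bodies in mmf.cuh below.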

ggml/src/ggml-cuda/mma.cuh

Lines changed: 66 additions & 15 deletions
@@ -77,6 +77,15 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 64 && J == 2) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 32 && J == 4) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 32) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                 return threadIdx.x % 16;
@@ -89,7 +98,7 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -106,14 +115,19 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 32) {
                 return threadIdx.x % 32;
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
 #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 32 && J == 8) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
@@ -122,7 +136,7 @@ namespace ggml_cuda_mma {
             return (l & 2) | (threadIdx.x & ~2);
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -131,23 +145,36 @@ namespace ggml_cuda_mma {
             if constexpr (I == 32 && J == 8) {
                 return (threadIdx.x & 2) | (l & (4 + 1));
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && (J == 4 || J == 8)) {
+            if constexpr (I == 8 && J == 4) {
                 return threadIdx.x / 4;
-            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+            } else if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
                 return ((l / 2) * 8) | (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
                 return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -157,12 +184,14 @@ namespace ggml_cuda_mma {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
                 return (l * 4) | (threadIdx.x % 4);
-            } else if constexpr ((I == 16 || I == 32) && J == 8) {
+            } else if constexpr (I == 16 && J == 8) {
                 return ((threadIdx.x % 4) * 2) | (l % 2);
             } else if constexpr (I == 16 && J == 16) {
                 return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+            } else if constexpr (I == 32 && J == 8) {
+                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -178,6 +207,12 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
@@ -188,7 +223,7 @@ namespace ggml_cuda_mma {
             return threadIdx.x;
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -197,14 +232,23 @@ namespace ggml_cuda_mma {
             if constexpr ((I == 8 || I == 32) && J == 8) {
                 return l;
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 4) return true;
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 32 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
@@ -215,7 +259,7 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 8) {
                 return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -230,7 +274,7 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 8) {
                 return ((l & 2) * 2) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -244,6 +288,13 @@ namespace ggml_cuda_mma {
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
+        static constexpr __device__ bool supported() {
+            if (I == 8 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
@@ -252,7 +303,7 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 16 && J == 8) {
                 return ((l % 2) * 8) | (threadIdx.x / 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -265,7 +316,7 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 16 && J == 8) {
                 return ((l / 2) * 4) | (threadIdx.x % 4);
             } else {
-                static_assert(I == -1 && J == -1, "template specialization not implemented");
+                NO_DEVICE_CODE;
                 return -1;
             }
         }
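
Why the static_assert had to go: a hedged aside, not part of the diff. The fallback branches above are instantiated whenever a kernel names an unimplemented <I, J> shape, and a plain runtime if (!supported()) guard does not suppress template instantiation, so a dependent static_assert there fails the build even for architectures that would never execute the branch. A condensed sketch of the removed pattern (get_i_old is hypothetical):

    template <int I, int J, typename T>
    static __device__ __forceinline__ int get_i_old(const int l) {
        if constexpr (I == 16 && J == 8) {
            return ((l / 2) * 8) | (threadIdx.x / 4);
        } else {
            // Fires for every instantiated-but-unlisted shape, reachable or not:
            static_assert(I == -1 && J == -1, "template specialization not implemented");
            return -1;
        }
    }

With NO_DEVICE_CODE the same instantiation compiles to a device-side trap that the constexpr supported() checks are meant to keep unreachable, which is what lets mmf.cuh below drop its per-type wrapper kernels.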

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 34 additions & 99 deletions
@@ -20,23 +20,27 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
 
 template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
-static __device__ void mul_mat_f_impl(
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-    typedef tile<32, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<32, 8, float> tile_C;
-#else
-    // In principle also possible to use tiles with I == 32, the performance difference is ~1%.
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
+    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+
+    if (!I_16_supported && !I_32_supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
+
+    typedef tile<I_preferred, 8, T>     tile_A;
+    typedef tile<8, 8, T>               tile_B;
+    typedef tile<I_preferred, 8, float> tile_C;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -238,43 +242,10 @@ static __device__ void mul_mat_f_impl(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f(
-        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
-        const int stride_col_id, const int stride_row_id,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-    if constexpr (std::is_same_v<T, half2>) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-        mul_mat_f_impl<T, rows_per_block, cols_per_block, nwarps, has_ids>(
-            x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
-            stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
-            stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-#else
-        NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-    } else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, nv_bfloat162>) {
-#ifdef AMPERE_MMA_AVAILABLE
-        mul_mat_f_impl<T, rows_per_block, cols_per_block, nwarps, has_ids>(
-            x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
-            stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
-            stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-#else
-        NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
-    } else {
-        static_assert(std::is_same_v<T, void>, "bad type");
-    }
-    GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
-        stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
-        stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-}
-
 //This kernel is for larger batch sizes of mul_mat_id
 template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-static __device__ void mul_mat_f_ids_impl(
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
         const T * __restrict__ x, const float * __restrict__ y,
         const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
@@ -283,16 +254,19 @@ static __device__ void mul_mat_f_ids_impl(
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-    typedef tile<32, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<32, 8, float> tile_C;
-#else
-    // In principle also possible to use tiles with I == 32, the performance difference is ~1%.
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
+    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+
+    if (!I_16_supported && !I_32_supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
+
+    typedef tile<I_preferred, 8, T>     tile_A;
+    typedef tile<8, 8, T>               tile_B;
+    typedef tile<I_preferred, 8, float> tile_C;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -521,46 +495,6 @@ static __device__ void mul_mat_f_ids_impl(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f_ids(
-        const T * __restrict__ x, const float * __restrict__ y,
-        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
-        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
-        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        const uint3 sis1_fd, const uint3 nch_fd) {
-    if constexpr (std::is_same_v<T, half2>) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-        mul_mat_f_ids_impl<T, rows_per_block, cols_per_block, nwarps>(
-            x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
-            ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-#else
-        NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
-    } else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, nv_bfloat162>) {
-#ifdef AMPERE_MMA_AVAILABLE
-        mul_mat_f_ids_impl<T, rows_per_block, cols_per_block, nwarps>(
-            x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
-            ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-#else
-        NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
-    } else {
-        static_assert(std::is_same_v<T, void>, "bad type");
-    }
-    GGML_UNUSED_VARS(
-        x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
-        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
-}
-
 template<typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
@@ -618,7 +552,7 @@ void mul_mat_f_cuda(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream, const mmf_ids_data * ids_data) {
-    typedef tile<32, 8, T> tile_A_16;
+    typedef tile<16, 8, T> tile_A_16;
     typedef tile<32, 8, T> tile_A_32;
     typedef tile< 8, 8, T> tile_B;
 
@@ -630,7 +564,8 @@ void mul_mat_f_cuda(
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
     const int64_t sample_ratio = nsamples_dst / nsamples_x;
 
-    const int device = ggml_cuda_get_device();
+    const int device    = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[device].cc;
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
 
     int64_t nwarps_best = 1;
@@ -645,7 +580,7 @@ void mul_mat_f_cuda(
     }
 
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * (volta_mma_available ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
+    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
     const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
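
As a sanity check on the nbytes_shared_iter formula in the last hunk, a worked example with assumed values (nwarps_best = 4 and warp_size = 32 are illustrative; the tile I heights and the (warp_size + 4) padding, i.e. tile_k_padded, come from the code above, and the trailing 4 matches sizeof(float)):

    static_assert(4 * 16 * (32 + 4) * 4 ==  9216, "per-iteration smem, 16-row A tile");
    static_assert(4 * 32 * (32 + 4) * 4 == 18432, "per-iteration smem, Volta 32-row A tile");

So on Volta, where only the 32-row tile is supported, the per-iteration shared-memory requirement doubles, which is why the host code now consults volta_mma_available(cc) rather than a compile-time constant when sizing the allocation.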
