Commit 6888c0d

Revert "CUDA: MoE helper in device code, better tile sizes (ggml-org#15525)"
This reverts commit 5eff6ec.
1 parent 78dc93d commit 6888c0d

4 files changed: +70 -223 lines changed
ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 20 deletions
@@ -424,28 +424,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {

 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
-    if (width == ggml_cuda_get_physical_warp_size()) {
-        return __all_sync(0xffffffff, x);
-    } else {
-#pragma unroll
-        for (int offset = width/2; offset > 0; offset >>= 1) {
-            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
-        }
-        return x;
-    }
-}
-
-template<int width = WARP_SIZE>
-static __device__ __forceinline__ int warp_reduce_any(int x) {
-    if (width == ggml_cuda_get_physical_warp_size()) {
-        return __any_sync(0xffffffff, x);
-    } else {
+#ifdef GGML_USE_HIP
 #pragma unroll
-        for (int offset = width/2; offset > 0; offset >>= 1) {
-            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
-        }
-        return x;
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
     }
+    return x;
+#else
+    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
+    return __all_sync(0xffffffff, x);
+#endif // GGML_USE_HIP
 }

 template<int width = WARP_SIZE>
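For context: the reverted warp_reduce_all keeps a shuffle-based reduction only on HIP and falls back to __all_sync on CUDA, with a static_assert restricting it to the full warp width, while the removed version (and the removed warp_reduce_any) also supported sub-warp tile widths by comparing against ggml_cuda_get_physical_warp_size(). Below is a minimal, hypothetical standalone sketch, not part of the ggml sources, of that shuffle-based boolean AND reduction over sub-warp tiles; the function and kernel names and the tile width of 8 are made up for illustration.

// Hypothetical sketch: AND-reduce a flag over tiles of `width` consecutive lanes,
// the pattern the removed warp_reduce_all<width> used via __shfl_xor_sync.
template <int width>
static __device__ __forceinline__ int tile_reduce_all(int x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        // Butterfly exchange: after log2(width) steps every lane in the tile
        // holds the AND over all `width` lanes of its tile.
        x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
    }
    return x;
}

// Example use, assuming a launch of one 32-thread block: 8-lane tiles inside a 32-lane warp.
__global__ void tile_reduce_all_demo(const int * in, int * out) {
    out[threadIdx.x] = tile_reduce_all<8>(in[threadIdx.x]);
}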

ggml/src/ggml-cuda/mmq.cu

Lines changed: 49 additions & 179 deletions
@@ -3,140 +3,6 @@

 #include <vector>

-// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
-struct mmq_ids_helper_store {
-    uint32_t data;
-
-    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
-        data = (it & 0x003FFFFF) | (iex_used << 22);
-    }
-
-    __device__ uint32_t it() const {
-        return data & 0x003FFFFF;
-    }
-
-    __device__ uint32_t iex_used() const {
-        return data >> 22;
-    }
-};
-static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
-
-// Helper function for mul_mat_id, converts ids to a more convenient format.
-// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
-// ids_dst describes the same mapping but for the dst tensor.
-// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
-template <int n_expert_used_template>
-__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
-static __global__ void mmq_ids_helper(
-        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
-    const int expert = blockIdx.x;
-
-    extern __shared__ char data_mmq_ids_helper[];
-    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
-
-    int nex_prev = 0; // Number of columns for experts with a lower index.
-    int it_compact = 0; // Running index for the compact slice of this expert.
-
-    if constexpr (n_expert_used_template == 0) {
-        // Generic implementation:
-        for (int it = 0; it < n_tokens; ++it) {
-            int iex_used = -1; // The index at which the expert is used, if any.
-            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
-                const int expert_used = ids[it*si1 + iex];
-                nex_prev += expert_used < expert;
-                if (expert_used == expert) {
-                    iex_used = iex;
-                }
-            }
-
-            if (iex_used != -1) {
-                store[it_compact] = mmq_ids_helper_store(it, iex_used);
-            }
-
-            if (warp_reduce_any<warp_size>(iex_used != -1)) {
-                it_compact++;
-            }
-        }
-    } else {
-        // Implementation optimized for specific numbers of experts used:
-        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
-        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
-        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
-            const int it = it0 + threadIdx.x / neu_padded;
-
-            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
-            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
-                ids[it*si1 + iex] : INT_MAX;
-            const int iex_used = expert_used == expert ? iex : -1;
-            nex_prev += expert_used < expert;
-
-            // Whether the threads at this token position have used the expert:
-            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
-
-            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
-            int it_compact_add_lower = 0;
-#pragma unroll
-            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
-                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
-                if (threadIdx.x >= offset) {
-                    it_compact_add_lower += tmp;
-                }
-            }
-
-            if (iex_used != -1) {
-                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
-            }
-
-            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
-            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
-        }
-    }
-    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
-
-    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
-        const mmq_ids_helper_store store_it = store[itc];
-        const int it = store_it.it();
-        const int iex_used = store_it.iex_used();
-        ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
-        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
-    }
-
-    if (threadIdx.x != 0) {
-        return;
-    }
-
-    expert_bounds[expert] = nex_prev;
-
-    if (expert < gridDim.x - 1) {
-        return;
-    }
-
-    expert_bounds[gridDim.x] = nex_prev + it_compact;
-}
-
-template <int n_expert_used_template>
-static void launch_mmq_ids_helper(
-        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
-    GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
-    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
-
-    const int id = ggml_cuda_get_device();
-    const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
-
-    const dim3 num_blocks(n_experts, 1, 1);
-    const dim3 block_size(warp_size, 1, 1);
-    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
-    GGML_ASSERT(nbytes_shared <= smpbo);
-    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
-        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
-}
-
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
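The deleted helper above packs each compacted entry into a single uint32_t: the token index "it" in the low 22 bits and the expert slot "iex_used" in the high 10 bits, which is why launch_mmq_ids_helper asserts n_tokens < (1 << 22) and n_expert_used_var < (1 << 10). A small host-side sketch of that packing follows; it is hypothetical test code written for this note, not part of the ggml sources.

#include <cassert>
#include <cstdint>

// 22 low bits for the token index, 10 high bits for the expert slot,
// mirroring the removed mmq_ids_helper_store.
static uint32_t pack_it_iex(uint32_t it, uint32_t iex_used) {
    assert(it < (1u << 22));
    assert(iex_used < (1u << 10));
    return (it & 0x003FFFFF) | (iex_used << 22);
}

static uint32_t unpack_it(uint32_t data)       { return data & 0x003FFFFF; }
static uint32_t unpack_iex_used(uint32_t data) { return data >> 22; }

int main() {
    const uint32_t packed = pack_it_iex(123456, 7);
    assert(unpack_it(packed) == 123456);
    assert(unpack_iex_used(packed) == 7);
    return 0;
}

As the removed comments describe, the kernel combined this packing with a __shfl_up_sync prefix sum over the warp so that each thread knew the compact slot to write its entry to.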
@@ -271,7 +137,7 @@ void ggml_cuda_mul_mat_q(
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
-            use_stream_k, ne1};
+            use_stream_k};
         ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
         return;
     }
@@ -282,50 +148,54 @@ void ggml_cuda_mul_mat_q(

     const int64_t n_expert_used = ids->ne[0];
     const int64_t ne_get_rows = ne12 * n_expert_used;
-    GGML_ASSERT(ne1 == n_expert_used);
-
-    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
-    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
-    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);

-    {
-        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
-        const int si1 = ids->nb[1] / ggml_element_size(ids);
-        const int sis1 = nb12 / nb11;
-
-        switch (n_expert_used) {
-            case 2:
-                launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 4:
-                launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 6:
-                launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 8:
-                launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 16:
-                launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 32:
-                launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            default:
-                launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    std::vector<int32_t> ids_src1_host;
+    ids_src1_host.reserve(ne_get_rows);
+    std::vector<int32_t> ids_dst_host;
+    ids_dst_host.reserve(ne_get_rows);
+    std::vector<int32_t> tokens_per_expert_host(ne02);
+    std::vector<int32_t> expert_bounds_host(ne02 + 1);
+    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
+
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                assert(expert_to_use >= 0 && expert_to_use < ne02);
+                if (expert_to_use == i02) {
+                    ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
+                    ids_dst_host.push_back(i12*ne1 + iex);
+                    tokens_per_expert_host[i02]++;
+                    break;
+                }
+            }
         }
-        CUDA_CHECK(cudaGetLastError());
     }

+    int32_t cumsum = 0;
+    for (int64_t i = 0; i < ne02; ++i) {
+        expert_bounds_host[i] = cumsum;
+        cumsum += tokens_per_expert_host[i];
+    }
+    expert_bounds_host[ne02] = cumsum;
+
+    std::vector<int32_t> ids_buf_host;
+    ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
+    ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
+    ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
+    ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
+    ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
+    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    const int32_t * ids_src1_dev = ids_buf_dev.ptr;
+    const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
+    const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
+
     const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
         get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
     ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
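The host-side path reinstated by this revert builds the same ids_src1/ids_dst/expert_bounds mapping on the CPU: it copies the ids tensor to the host, gathers one (token, slot) entry per matching expert, counts tokens per expert, turns the counts into expert_bounds via a prefix sum, and uploads everything as one concatenated buffer. A self-contained, hypothetical sketch of that bookkeeping with toy data is shown below; the names and sizes are illustrative only and do not use the ggml API.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_experts = 4, n_tokens = 5, n_expert_used = 2;
    // ids[t][k] = expert used by token t in slot k (toy data).
    const int32_t ids[n_tokens][n_expert_used] = {{0,2},{1,2},{3,0},{2,1},{0,3}};

    std::vector<int32_t> tokens_per_expert(n_experts, 0);
    std::vector<int32_t> ids_dst; // flattened (token, slot) indices, grouped by expert
    for (int e = 0; e < n_experts; ++e) {
        for (int t = 0; t < n_tokens; ++t) {
            for (int k = 0; k < n_expert_used; ++k) {
                if (ids[t][k] == e) {
                    ids_dst.push_back(t*n_expert_used + k);
                    tokens_per_expert[e]++;
                    break; // at most one slot per (token, expert) pair is counted
                }
            }
        }
    }

    // expert_bounds[e]..expert_bounds[e+1] delimits expert e's slice of the compact buffer.
    std::vector<int32_t> expert_bounds(n_experts + 1);
    int32_t cumsum = 0;
    for (int e = 0; e < n_experts; ++e) {
        expert_bounds[e] = cumsum;
        cumsum += tokens_per_expert[e];
    }
    expert_bounds[n_experts] = cumsum;

    for (int e = 0; e <= n_experts; ++e) {
        printf("expert_bounds[%d] = %d\n", e, (int) expert_bounds[e]);
    }
    return 0;
}

Compared with the deleted mmq_ids_helper kernel, this trades device-side compaction for a host loop and the two cudaStreamSynchronize calls visible in the hunk above.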
@@ -338,7 +208,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
+        quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
             ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
         CUDA_CHECK(cudaGetLastError());
     }
@@ -348,11 +218,11 @@ void ggml_cuda_mul_mat_q(

     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
-        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
+        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
-        use_stream_k, ne12};
+        use_stream_k};

     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 }
@@ -392,7 +262,7 @@ void ggml_cuda_op_mul_mat_q(
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
-        use_stream_k, src1_ncols};
+        use_stream_k};

     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
