#include <vector>

+ // To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+ struct mmq_ids_helper_store {
+     uint32_t data;
+
+     __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+         data = (it & 0x003FFFFF) | (iex_used << 22);
+     }
+
+     __device__ uint32_t it() const {
+         return data & 0x003FFFFF;
+     }
+
+     __device__ uint32_t iex_used() const {
+         return data >> 22;
+     }
+ };
+ static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
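As a quick illustration of the packing above (an editor's sketch, not part of the patch): the token index occupies the low 22 bits and the expert-slot index the high 10 bits, so up to 2^22 tokens and 2^10 used experts fit in one 32-bit word. A minimal host-side round-trip check:

// Editor's illustration only: round-trip of the 22/10-bit packing used by mmq_ids_helper_store.
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t it = 5, iex_used = 3;
    const uint32_t data = (it & 0x003FFFFF) | (iex_used << 22); // == 0x00C00005
    assert((data & 0x003FFFFF) == it);       // low 22 bits: token index
    assert((data >> 22)        == iex_used); // high 10 bits: expert-slot index
    return 0;
}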
+
+ // Helper function for mul_mat_id, converts ids to a more convenient format.
+ // ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+ // ids_dst describes the same mapping but for the dst tensor.
+ // The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
+ template <int n_expert_used_template>
+ __launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+ static __global__ void mmq_ids_helper(
+         const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+         const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+     const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+     const int expert = blockIdx.x;
+
+     extern __shared__ char data_mmq_ids_helper[];
+     mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
+
+     int nex_prev   = 0; // Number of columns for experts with a lower index.
+     int it_compact = 0; // Running index for the compact slice of this expert.
+
+     if constexpr (n_expert_used_template == 0) {
+         // Generic implementation:
+         for (int it = 0; it < n_tokens; ++it) {
+             int iex_used = -1; // The index at which the expert is used, if any.
+             for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+                 const int expert_used = ids[it*si1 + iex];
+                 nex_prev += expert_used < expert;
+                 if (expert_used == expert) {
+                     iex_used = iex;
+                 }
+             }
+
+             if (iex_used != -1) {
+                 store[it_compact] = mmq_ids_helper_store(it, iex_used);
+             }
+
+             if (warp_reduce_any<warp_size>(iex_used != -1)) {
+                 it_compact++;
+             }
+         }
+     } else {
+         // Implementation optimized for specific numbers of experts used:
+         static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+         const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+         for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+             const int it = it0 + threadIdx.x / neu_padded;
+
+             const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+             const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+                 ids[it*si1 + iex] : INT_MAX;
+             const int iex_used = expert_used == expert ? iex : -1;
+             nex_prev += expert_used < expert;
+
+             // Whether the threads at this token position have used the expert:
+             const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+             // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+             int it_compact_add_lower = 0;
+ #pragma unroll
+             for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+                 const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+                 if (threadIdx.x >= offset) {
+                     it_compact_add_lower += tmp;
+                 }
+             }
+
+             if (iex_used != -1) {
+                 store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+             }
+
+             // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+             it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+         }
+     }
+     nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+     for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+         const mmq_ids_helper_store store_it = store[itc];
+         const int it       = store_it.it();
+         const int iex_used = store_it.iex_used();
+         ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
+         ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+     }
+
+     if (threadIdx.x != 0) {
+         return;
+     }
+
+     expert_bounds[expert] = nex_prev;
+
+     if (expert < gridDim.x - 1) {
+         return;
+     }
+
+     expert_bounds[gridDim.x] = nex_prev + it_compact;
+ }
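For reference, the mapping the kernel above produces can be summarized by the following sequential sketch (an editor's illustration, not part of the patch; the helper name is hypothetical and it assumes, as the kernel does, that each expert appears at most once in a token's id list):

// Hypothetical host-side reference for the output format of mmq_ids_helper.
// For each expert, the compact columns are the tokens that use it, in token order;
// expert_bounds[e] .. expert_bounds[e+1] delimit expert e's slice.
#include <cstdint>

static void mmq_ids_helper_ref(
        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used,
        const int nchannels_y, const int si1, const int sis1) {
    int offset = 0;
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = offset;
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*si1 + iex] != expert) {
                    continue;
                }
                ids_src1[offset] = it*sis1          + iex % nchannels_y;
                ids_dst [offset] = it*n_expert_used + iex;
                ++offset;
                break; // at most one slot per token refers to a given expert
            }
        }
    }
    expert_bounds[n_experts] = offset;
}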
+
+ template <int n_expert_used_template>
+ static void launch_mmq_ids_helper(
+         const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+         const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+     GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
+     GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
+
+     const int    id        = ggml_cuda_get_device();
+     const int    warp_size = ggml_cuda_info().devices[id].warp_size;
+     const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
+     CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
+     const dim3 num_blocks(n_experts, 1, 1);
+     const dim3 block_size(warp_size, 1, 1);
+     const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
+     GGML_ASSERT(nbytes_shared <= smpbo);
+     mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+         (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+ }
+
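The shared-memory footprint of the launch is easy to bound (editor's note, not part of the patch): one 4-byte mmq_ids_helper_store entry per token, which is why the launcher asserts nbytes_shared <= smpbo and caps n_tokens at 2^22. A small worked example under an assumed token count:

// Editor's illustration: shared memory needed per block by mmq_ids_helper.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t n_tokens      = 8192;         // assumed number of flattened src1 columns
    const size_t nbytes_shared = n_tokens * 4; // sizeof(mmq_ids_helper_store) == 4
    std::printf("%zu tokens -> %zu KiB of shared memory\n", n_tokens, nbytes_shared / 1024);
    return 0;                                  // prints: 8192 tokens -> 32 KiB of shared memory
}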
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    switch (args.type_x) {
        case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
            ne00, ne01, ne1, s01, ne11, s1,
            ne02, ne12, s02, s12, s2,
            ne03, ne13, s03, s13, s3,
-           use_stream_k};
+           use_stream_k, ne1};
        ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
        return;
    }
@@ -148,53 +282,49 @@ void ggml_cuda_mul_mat_q(

    const int64_t n_expert_used = ids->ne[0];
    const int64_t ne_get_rows   = ne12 * n_expert_used;
+   GGML_ASSERT(ne1 == n_expert_used);

-   std::vector<char> ids_host(ggml_nbytes(ids));
-   std::vector<int32_t> ids_src1_host;
-   ids_src1_host.reserve(ne_get_rows);
-   std::vector<int32_t> ids_dst_host;
-   ids_dst_host.reserve(ne_get_rows);
-   std::vector<int32_t> tokens_per_expert_host(ne02);
-   std::vector<int32_t> expert_bounds_host(ne02 + 1);
-   ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
-
-   CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-   CUDA_CHECK(cudaStreamSynchronize(stream));
-
-   for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
-       for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
-           for (int64_t iex = 0; iex < n_expert_used; ++iex) {
-               const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
-               assert(expert_to_use >= 0 && expert_to_use < ne02);
-               if (expert_to_use == i02) {
-                   ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
-                   ids_dst_host.push_back(i12*ne1 + iex);
-                   tokens_per_expert_host[i02]++;
-                   break;
-               }
-           }
-       }
-   }
+   ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+   ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+   ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);

-   int32_t cumsum = 0;
-   for (int64_t i = 0; i < ne02; ++i) {
-       expert_bounds_host[i] = cumsum;
-       cumsum += tokens_per_expert_host[i];
+   {
+       GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+       const int si1  = ids->nb[1] / ggml_element_size(ids);
+       const int sis1 = nb12 / nb11;
+
+       switch (n_expert_used) {
+           case  2:
+               launch_mmq_ids_helper< 2>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+           case  4:
+               launch_mmq_ids_helper< 4>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+           case  6:
+               launch_mmq_ids_helper< 6>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+           case  8:
+               launch_mmq_ids_helper< 8>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+           case 16:
+               launch_mmq_ids_helper<16>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+           case 32:
+               launch_mmq_ids_helper<32>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+           default:
+               launch_mmq_ids_helper< 0>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                   ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+               break;
+       }
+       CUDA_CHECK(cudaGetLastError());
    }
-   expert_bounds_host[ne02] = cumsum;
-
-   std::vector<int32_t> ids_buf_host;
-   ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
-   ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
-   ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
-   ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
-   ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
-   CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
-   CUDA_CHECK(cudaStreamSynchronize(stream));
-
-   const int32_t * ids_src1_dev      = ids_buf_dev.ptr;
-   const int32_t * ids_dst_dev       = ids_src1_dev + ids_src1_host.size();
-   const int32_t * expert_bounds_dev = ids_dst_dev  + ids_dst_host.size();

    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
        const int64_t s11 = src1->nb[1] / ts_src1;
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[2] / ts_src1;
-       quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+       quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
        CUDA_CHECK(cudaGetLastError());
    }
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(

    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
    const mmq_args args = {
-       src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+       src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
        ne02, ne02, s02, s12, s2,
        ne03, ne13, s03, s13, s3,
-       use_stream_k};
+       use_stream_k, ne12};

    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
        1, 1, 0, 0, 0,
        1, 1, 0, 0, 0,
-       use_stream_k};
+       use_stream_k, src1_ncols};

    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);