 
  #include <vector>
 
- // To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
- struct mmq_ids_helper_store {
-     uint32_t data;
-
-     __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
-         data = (it & 0x003FFFFF) | (iex_used << 22);
-     }
-
-     __device__ uint32_t it() const {
-         return data & 0x003FFFFF;
-     }
-
-     __device__ uint32_t iex_used() const {
-         return data >> 22;
-     }
- };
- static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
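
Aside: a minimal host-only sketch of the 22/10-bit packing used by the struct above (not part of the patch; ids_store_demo is a hypothetical stand-in for mmq_ids_helper_store with the __device__ qualifiers dropped). It shows that it() and iex_used() round-trip any values below 2^22 and 2^10, respectively:

#include <cassert>
#include <cstdint>

// Hypothetical host-only stand-in for mmq_ids_helper_store (same packing, no __device__).
struct ids_store_demo {
    uint32_t data;
    ids_store_demo(const uint32_t it, const uint32_t iex_used) : data((it & 0x003FFFFF) | (iex_used << 22)) {}
    uint32_t it()       const { return data & 0x003FFFFF; }
    uint32_t iex_used() const { return data >> 22; }
};

int main() {
    const ids_store_demo s(/*it=*/123456, /*iex_used=*/7);
    assert(s.it()       == 123456); // lower 22 bits recover the token index
    assert(s.iex_used() == 7);      // upper 10 bits recover the expert slot
    return 0;
}
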
-
- // Helper function for mul_mat_id, converts ids to a more convenient format.
- // ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
- // ids_dst describes the same mapping but for the dst tensor.
- // The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
- template <int n_expert_used_template>
- __launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
- static __global__ void mmq_ids_helper(
-         const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-         const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
-     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-     const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
-     const int expert = blockIdx.x;
-
-     extern __shared__ char data_mmq_ids_helper[];
-     mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
-
-     int nex_prev   = 0; // Number of columns for experts with a lower index.
-     int it_compact = 0; // Running index for the compact slice of this expert.
-
-     if constexpr (n_expert_used_template == 0) {
-         // Generic implementation:
-         for (int it = 0; it < n_tokens; ++it) {
-             int iex_used = -1; // The index at which the expert is used, if any.
-             for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
-                 const int expert_used = ids[it*si1 + iex];
-                 nex_prev += expert_used < expert;
-                 if (expert_used == expert) {
-                     iex_used = iex;
-                 }
-             }
-
-             if (iex_used != -1) {
-                 store[it_compact] = mmq_ids_helper_store(it, iex_used);
-             }
-
-             if (warp_reduce_any<warp_size>(iex_used != -1)) {
-                 it_compact++;
-             }
-         }
-     } else {
-         // Implementation optimized for specific numbers of experts used:
-         static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
-         const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
-         for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
-             const int it = it0 + threadIdx.x / neu_padded;
-
-             const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
-             const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
-                 ids[it*si1 + iex] : INT_MAX;
-             const int iex_used = expert_used == expert ? iex : -1;
-             nex_prev += expert_used < expert;
-
-             // Whether the threads at this token position have used the expert:
-             const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
-
-             // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
-             int it_compact_add_lower = 0;
- #pragma unroll
-             for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
-                 const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
-                 if (threadIdx.x >= offset) {
-                     it_compact_add_lower += tmp;
-                 }
-             }
-
-             if (iex_used != -1) {
-                 store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
-             }
-
-             // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
-             it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
-         }
-     }
-     nex_prev = warp_reduce_sum<warp_size>(nex_prev);
-
-     for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
-         const mmq_ids_helper_store store_it = store[itc];
-         const int it       = store_it.it();
-         const int iex_used = store_it.iex_used();
-         ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
-         ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
-     }
-
-     if (threadIdx.x != 0) {
-         return;
-     }
-
-     expert_bounds[expert] = nex_prev;
-
-     if (expert < gridDim.x - 1) {
-         return;
-     }
-
-     expert_bounds[gridDim.x] = nex_prev + it_compact;
- }
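
For intuition, here is a host-only sketch (not part of the patch) of the mapping described in the comment above mmq_ids_helper, run on a toy routing of 3 tokens over 4 experts with 2 experts used per token. Assumptions for this sketch: a contiguous ids tensor, each token routing to distinct experts, and a broadcast src1 with one column per token (nchannels_y == sis1 == 1):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int n_experts = 4, n_tokens = 3, n_expert_used = 2;
    // Assumption for this sketch: src1 is broadcast, i.e. one src1 column per token.
    const int nchannels_y = 1, sis1 = 1;
    // ids[it*n_expert_used + iex]: token 0 -> experts {0, 2}, token 1 -> {2, 1}, token 2 -> {0, 1}.
    const std::vector<int32_t> ids = {0, 2,  2, 1,  0, 1};

    std::vector<int32_t> ids_src1, ids_dst, expert_bounds;
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds.push_back((int32_t) ids_src1.size()); // lower bound of this expert's compact slice
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*n_expert_used + iex] == expert) {
                    ids_src1.push_back(it*sis1 + iex % nchannels_y);
                    ids_dst.push_back(it*n_expert_used + iex);
                }
            }
        }
    }
    expert_bounds.push_back((int32_t) ids_src1.size()); // total number of compact columns

    // Compact columns are grouped by expert; expert_bounds[i] and expert_bounds[i+1] bracket expert i's slice.
    assert((expert_bounds == std::vector<int32_t>{0, 2, 4, 6, 6}));
    assert((ids_src1 == std::vector<int32_t>{0, 2, 1, 2, 0, 1})); // which src1 column feeds each compact column
    assert((ids_dst  == std::vector<int32_t>{0, 4, 3, 5, 1, 2})); // where each compact result column lands in dst
    return 0;
}
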
-
- template <int n_expert_used_template>
- static void launch_mmq_ids_helper(
-         const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-         const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
-     GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
-     GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
-
-     const int id        = ggml_cuda_get_device();
-     const int warp_size = ggml_cuda_info().devices[id].warp_size;
-     const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;
-     CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
-
-     const dim3 num_blocks(n_experts, 1, 1);
-     const dim3 block_size(warp_size, 1, 1);
-     const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
-     GGML_ASSERT(nbytes_shared <= smpbo);
-     mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
-         (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
- }
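
As a worked example of the shared-memory bound asserted above (the 48 KiB figure is illustrative, not taken from the source): each token occupies sizeof(mmq_ids_helper_store) = 4 bytes of dynamic shared memory, so a device whose opt-in per-block limit smpbo is 48 KiB can process up to 49152 / 4 = 12288 tokens per launch before GGML_ASSERT(nbytes_shared <= smpbo) fires.
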
-
  static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
      switch (args.type_x) {
          case GGML_TYPE_Q4_0:
@@ -271,7 +137,7 @@ void ggml_cuda_mul_mat_q(
              ne00, ne01, ne1, s01, ne11, s1,
              ne02, ne12, s02, s12, s2,
              ne03, ne13, s03, s13, s3,
-             use_stream_k, ne1};
+             use_stream_k};
          ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
          return;
      }
@@ -282,50 +148,54 @@ void ggml_cuda_mul_mat_q(
 
      const int64_t n_expert_used = ids->ne[0];
      const int64_t ne_get_rows = ne12 * n_expert_used;
-     GGML_ASSERT(ne1 == n_expert_used);
-
-     ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
-     ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
-     ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
 
-     {
-         GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
-         const int si1 = ids->nb[1] / ggml_element_size(ids);
-         const int sis1 = nb12 / nb11;
-
-         switch (n_expert_used) {
-             case 2:
-                 launch_mmq_ids_helper< 2>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
-             case 4:
-                 launch_mmq_ids_helper< 4>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
-             case 6:
-                 launch_mmq_ids_helper< 6>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
-             case 8:
-                 launch_mmq_ids_helper< 8>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
-             case 16:
-                 launch_mmq_ids_helper<16>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
-             case 32:
-                 launch_mmq_ids_helper<32>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
-             default:
-                 launch_mmq_ids_helper< 0>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                     ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                 break;
+     std::vector<char> ids_host(ggml_nbytes(ids));
+     std::vector<int32_t> ids_src1_host;
+     ids_src1_host.reserve(ne_get_rows);
+     std::vector<int32_t> ids_dst_host;
+     ids_dst_host.reserve(ne_get_rows);
+     std::vector<int32_t> tokens_per_expert_host(ne02);
+     std::vector<int32_t> expert_bounds_host(ne02 + 1);
+     ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
+
+     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+     CUDA_CHECK(cudaStreamSynchronize(stream));
+
+     for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+         for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+             for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                 const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                 assert(expert_to_use >= 0 && expert_to_use < ne02);
+                 if (expert_to_use == i02) {
+                     ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
+                     ids_dst_host.push_back(i12*ne1 + iex);
+                     tokens_per_expert_host[i02]++;
+                     break;
+                 }
+             }
          }
-         CUDA_CHECK(cudaGetLastError());
      }
 
+     int32_t cumsum = 0;
+     for (int64_t i = 0; i < ne02; ++i) {
+         expert_bounds_host[i] = cumsum;
+         cumsum += tokens_per_expert_host[i];
+     }
+     expert_bounds_host[ne02] = cumsum;
+
+     std::vector<int32_t> ids_buf_host;
+     ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
+     ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
+     ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
+     ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
+     ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
+     CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+     CUDA_CHECK(cudaStreamSynchronize(stream));
+
+     const int32_t * ids_src1_dev      = ids_buf_dev.ptr;
+     const int32_t * ids_dst_dev       = ids_src1_dev + ids_src1_host.size();
+     const int32_t * expert_bounds_dev = ids_dst_dev  + ids_dst_host.size();
+
      const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
          get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
      ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
@@ -338,7 +208,7 @@ void ggml_cuda_mul_mat_q(
          const int64_t s11 = src1->nb[1] / ts_src1;
          const int64_t s12 = src1->nb[2] / ts_src1;
          const int64_t s13 = src1->nb[2] / ts_src1;
-         quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
+         quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
              ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
          CUDA_CHECK(cudaGetLastError());
      }
@@ -348,11 +218,11 @@ void ggml_cuda_mul_mat_q(
 
      // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
      const mmq_args args = {
-         src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
+         src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
          ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
          ne02, ne02, s02, s12, s2,
          ne03, ne13, s03, s13, s3,
-         use_stream_k, ne12};
+         use_stream_k};
 
      ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
  }
@@ -392,7 +262,7 @@ void ggml_cuda_op_mul_mat_q(
          ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
          1, 1, 0, 0, 0,
          1, 1, 0, 0, 0,
-         use_stream_k, src1_ncols};
+         use_stream_k};
 
      ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 