#include <vector>
// To reduce shared memory use, store "it" and "iex_used" packed into a single
// 32 bit value: the low 22 bits hold "it" (token index), the high 10 bits hold
// "iex_used" (index of the expert used). The launch-side asserts guarantee
// n_tokens < 2^22 and n_expert_used < 2^10, so neither field is truncated.
struct mmq_ids_helper_store {
    uint32_t data;

    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
        data = (it & 0x003FFFFF) | (iex_used << 22);
    }

    // Token index, lower 22 bits.
    __device__ uint32_t it() const {
        return data & 0x003FFFFF;
    }

    // Used-expert index, upper 10 bits.
    __device__ uint32_t iex_used() const {
        return data >> 22;
    }
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
@@ -17,9 +34,8 @@ static __global__ void mmq_ids_helper(
1734 const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
1835 const int expert = blockIdx .x ;
1936
20- extern __shared__ int data_mmq_ids_helper[];
21- int * ids_src1_shared = data_mmq_ids_helper;
22- int * ids_dst_shared = ids_src1_shared + n_tokens;
37+ extern __shared__ char data_mmq_ids_helper[];
38+ mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
2339
2440 int nex_prev = 0 ; // Number of columns for experts with a lower index.
2541 int it_compact = 0 ; // Running index for the compact slice of this expert.
@@ -37,8 +53,7 @@ static __global__ void mmq_ids_helper(
3753 }
3854
3955 if (iex_used != -1 ) {
40- ids_src1_shared[it_compact] = it*sis1 + iex_used % nchannels_y;
41- ids_dst_shared[it_compact] = it*n_expert_used + iex_used;
56+ store[it_compact] = mmq_ids_helper_store (it, iex_used);
4257 }
4358
4459 if (warp_reduce_any<warp_size>(iex_used != -1 )) {
@@ -72,8 +87,7 @@ static __global__ void mmq_ids_helper(
7287 }
7388
7489 if (iex_used != -1 ) {
75- ids_src1_shared[it_compact + it_compact_add_lower] = it*sis1 + iex_used % nchannels_y;
76- ids_dst_shared[it_compact + it_compact_add_lower] = it*n_expert_used + iex_used;
90+ store[it_compact + it_compact_add_lower] = mmq_ids_helper_store (it, iex_used);
7791 }
7892
7993 // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
@@ -82,9 +96,12 @@ static __global__ void mmq_ids_helper(
8296 }
8397 nex_prev = warp_reduce_sum<warp_size>(nex_prev);
8498
85- for (int it = threadIdx .x ; it < it_compact; it += warp_size) {
86- ids_src1[nex_prev + it] = ids_src1_shared[it];
87- ids_dst [nex_prev + it] = ids_dst_shared [it];
99+ for (int itc = threadIdx .x ; itc < it_compact; itc += warp_size) {
100+ const mmq_ids_helper_store store_it = store[itc];
101+ const int it = store_it.it ();
102+ const int iex_used = store_it.iex_used ();
103+ ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
104+ ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
88105 }
89106
90107 if (threadIdx .x != 0 ) {
// Host-side wrapper for the mmq_ids_helper kernel: one block per expert, one
// warp per block, n_tokens * 4 bytes of dynamic shared memory per block.
// Preconditions: n_tokens < 2^22 and n_expert_used_var < 2^10 so that both
// values fit into the packed mmq_ids_helper_store (22/10 bits).
template <int n_expert_used_template>
static void launch_mmq_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");

    const int    id        = ggml_cuda_get_device();
    const int    warp_size = ggml_cuda_info().devices[id].warp_size;
    const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
    // Opt in to the maximum dynamic shared memory per block for this kernel.
    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);

    const dim3 num_blocks(n_experts, 1, 1);
    const dim3 block_size(warp_size, 1, 1);
    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
    // Fail loudly here instead of getting a silent launch failure later.
    GGML_ASSERT(nbytes_shared <= smpbo && "too much shared memory requested for mmq_ids_helper");
    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}