Skip to content

Commit a2f702a

Browse files
reduce shared memory use
1 parent 5724990 commit a2f702a

File tree

1 file changed

+31
-11
lines changed

1 file changed

+31
-11
lines changed

ggml/src/ggml-cuda/mmq.cu

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,23 @@
33

44
#include <vector>
55

6+
// Compact shared-memory record for mmq_ids_helper: to reduce shared memory use,
// "it" and "iex_used" are packed into one 32 bit word, with 22 bits for "it"
// (low bits) and 10 bits for "iex_used" (high bits).
struct mmq_ids_helper_store {
    uint32_t data; // bits [0, 22): it, bits [22, 32): iex_used

    // Pack the pair into a single word. "it" is truncated to its low 22 bits;
    // any excess bits of "iex_used" fall off the top of the 32 bit word.
    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
        data = (iex_used << 22) | (it & ((1u << 22) - 1));
    }

    // Retrieve the stored "it" value (low 22 bits).
    __device__ uint32_t it() const {
        return data & ((1u << 22) - 1);
    }

    // Retrieve the stored "iex_used" value (high 10 bits).
    __device__ uint32_t iex_used() const {
        return data >> 22;
    }
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
623

724
// Helper function for mul_mat_id, converts ids to a more convenient format.
825
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
@@ -17,9 +34,8 @@ static __global__ void mmq_ids_helper(
1734
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
1835
const int expert = blockIdx.x;
1936

20-
extern __shared__ int data_mmq_ids_helper[];
21-
int * ids_src1_shared = data_mmq_ids_helper;
22-
int * ids_dst_shared = ids_src1_shared + n_tokens;
37+
extern __shared__ char data_mmq_ids_helper[];
38+
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
2339

2440
int nex_prev = 0; // Number of columns for experts with a lower index.
2541
int it_compact = 0; // Running index for the compact slice of this expert.
@@ -37,8 +53,7 @@ static __global__ void mmq_ids_helper(
3753
}
3854

3955
if (iex_used != -1) {
40-
ids_src1_shared[it_compact] = it*sis1 + iex_used % nchannels_y;
41-
ids_dst_shared[it_compact] = it*n_expert_used + iex_used;
56+
store[it_compact] = mmq_ids_helper_store(it, iex_used);
4257
}
4358

4459
if (warp_reduce_any<warp_size>(iex_used != -1)) {
@@ -72,8 +87,7 @@ static __global__ void mmq_ids_helper(
7287
}
7388

7489
if (iex_used != -1) {
75-
ids_src1_shared[it_compact + it_compact_add_lower] = it*sis1 + iex_used % nchannels_y;
76-
ids_dst_shared[it_compact + it_compact_add_lower] = it*n_expert_used + iex_used;
90+
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
7791
}
7892

7993
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
@@ -82,9 +96,12 @@ static __global__ void mmq_ids_helper(
8296
}
8397
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
8498

85-
for (int it = threadIdx.x; it < it_compact; it += warp_size) {
86-
ids_src1[nex_prev + it] = ids_src1_shared[it];
87-
ids_dst [nex_prev + it] = ids_dst_shared [it];
99+
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
100+
const mmq_ids_helper_store store_it = store[itc];
101+
const int it = store_it.it();
102+
const int iex_used = store_it.iex_used();
103+
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
104+
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
88105
}
89106

90107
if (threadIdx.x != 0) {
@@ -104,14 +121,17 @@ template <int n_expert_used_template>
104121
static void launch_mmq_ids_helper(
105122
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
106123
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
124+
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
125+
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
126+
107127
const int id = ggml_cuda_get_device();
108128
const int warp_size = ggml_cuda_info().devices[id].warp_size;
109129
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
110130
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
111131

112132
const dim3 num_blocks(n_experts, 1, 1);
113133
const dim3 block_size(warp_size, 1, 1);
114-
const size_t nbytes_shared = 2*n_tokens*sizeof(int);
134+
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
115135
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
116136
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
117137
}

0 commit comments

Comments
 (0)