
Commit 5724990

raise shared memory limit
1 parent e7b884d commit 5724990

1 file changed: 32 additions & 27 deletions

ggml/src/ggml-cuda/mmq.cu
@@ -100,6 +100,22 @@ static __global__ void mmq_ids_helper(
     expert_bounds[gridDim.x] = nex_prev + it_compact;
 }
 
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
+    const dim3 num_blocks(n_experts, 1, 1);
+    const dim3 block_size(warp_size, 1, 1);
+    const size_t nbytes_shared = 2*n_tokens*sizeof(int);
+    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
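
The interesting part of this new helper is the shared-memory handling: it queries the device's maximum shared memory per block with opt-in (smpbo) and raises the kernel's limit via CUDA_SET_SHARED_MEMORY_LIMIT before launching. This matters because CUDA caps dynamic shared memory at 48 KiB per block by default, and a kernel must opt in per function to use more. With nbytes_shared = 2*n_tokens*sizeof(int) the kernel needs 8 bytes of shared memory per token, so the default cap would limit it to 6144 tokens; opting in up to smpbo (typically 96 KiB or more on Volta and newer) raises that ceiling, which is what the commit title refers to.

A minimal stand-alone sketch of that opt-in, assuming CUDA_SET_SHARED_MEMORY_LIMIT wraps cudaFuncSetAttribute (the kernel demo_kernel and all scaffolding here are illustrative, not from the commit):

    #include <cuda_runtime.h>
    #include <cstdio>

    // Hypothetical kernel standing in for mmq_ids_helper: it only exists to
    // consume a caller-chosen amount of dynamic shared memory.
    __global__ void demo_kernel(int * out, const int n) {
        extern __shared__ int buf[]; // dynamic shared memory, sized at launch
        for (int i = threadIdx.x; i < n; i += blockDim.x) {
            buf[i] = i;
        }
        __syncthreads();
        for (int i = threadIdx.x; i < n; i += blockDim.x) {
            out[i] = buf[i];
        }
    }

    int main() {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, 0);

        // Launches with more than 48 KiB of dynamic shared memory fail unless
        // this per-function attribute is raised first (the "opt-in").
        const size_t smpbo = prop.sharedMemPerBlockOptin;
        cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smpbo);

        const int n = (int) (smpbo/sizeof(int)); // use the full opt-in allowance
        int * out;
        cudaMalloc(&out, n*sizeof(int));
        demo_kernel<<<1, 256, smpbo>>>(out, n);
        printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(out);
        return 0;
    }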
@@ -174,9 +190,7 @@ void ggml_cuda_mul_mat_q(
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     cudaStream_t stream = ctx.stream();
-    const int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
     const size_t ts_src0 = ggml_type_size(src0->type);
     const size_t ts_src1 = ggml_type_size(src1->type);
@@ -258,46 +272,37 @@ void ggml_cuda_mul_mat_q(
         const int si1 = ids->nb[1] / ggml_element_size(ids);
         const int sis1 = nb12 / nb11;
 
-        const dim3 num_blocks(ne02, 1, 1);
-        const dim3 block_size(warp_size, 1, 1);
-        const size_t nbytes_shared = 2*ne12*sizeof(int);
         switch (n_expert_used) {
             case 2:
-                mmq_ids_helper< 2><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 4:
-                mmq_ids_helper< 4><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 6:
-                mmq_ids_helper< 6><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 8:
-                mmq_ids_helper< 8><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 16:
-                mmq_ids_helper<16><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 32:
-                mmq_ids_helper<32><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             default:
-                mmq_ids_helper<0><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
         }
+        CUDA_CHECK(cudaGetLastError());
     }
 
     const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
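
Two details of the rewritten dispatch are worth noting. First, the switch selects a compile-time specialization for each common expert count; since the helper receives both n_expert_used_template and a runtime n_expert_used_var, the < 0> case presumably serves as a generic fallback when the count is not one of the specialized values. Second, the newly added CUDA_CHECK(cudaGetLastError()) surfaces launch-configuration failures, such as a dynamic shared-memory request the device cannot satisfy, which kernel launches otherwise report only asynchronously.

A minimal sketch of the dispatch pattern, under the assumption that a template argument of 0 means "not known at compile time" (the names process and dispatch are hypothetical, not from the commit):

    #include <cstdio>

    // When n_template != 0 the count is a compile-time constant, which lets
    // the compiler unroll loops over it; 0 falls back to the runtime value.
    template <int n_template>
    static void process(const int n_var) {
        const int n = n_template == 0 ? n_var : n_template;
        printf("processing %d experts (%s)\n", n,
            n_template == 0 ? "runtime value" : "compile-time constant");
    }

    static void dispatch(const int n_expert_used) {
        switch (n_expert_used) {
            case  2: process< 2>(n_expert_used); break;
            case  4: process< 4>(n_expert_used); break;
            case  8: process< 8>(n_expert_used); break;
            default: process< 0>(n_expert_used); break; // generic fallback
        }
    }

    int main() {
        dispatch(4); // hits the specialized path
        dispatch(5); // hits the generic fallback
        return 0;
    }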
