@@ -100,6 +100,22 @@ static __global__ void mmq_ids_helper(
     expert_bounds[gridDim.x] = nex_prev + it_compact;
 }
 
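+// Common launcher for the mmq_ids_helper kernel: sets the kernel's shared
+// memory limit and the launch configuration in one place instead of
+// repeating them at every call site.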
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
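+    // smpbo = maximum shared memory per block with opt-in; raise the kernel's
+    // dynamic shared memory limit so allocations above the default still fit.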
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
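+    // One thread block per expert, one warp per block; the kernel uses two
+    // ints of shared memory per token.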
+    const dim3 num_blocks(n_experts, 1, 1);
+    const dim3 block_size(warp_size, 1, 1);
+    const size_t nbytes_shared = 2*n_tokens*sizeof(int);
+    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
@@ -174,9 +190,7 @@ void ggml_cuda_mul_mat_q(
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     cudaStream_t stream = ctx.stream();
-    const int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
     const size_t ts_src0 = ggml_type_size(src0->type);
     const size_t ts_src1 = ggml_type_size(src1->type);
@@ -258,46 +272,37 @@ void ggml_cuda_mul_mat_q(
         const int si1 = ids->nb[1] / ggml_element_size(ids);
         const int sis1 = nb12 / nb11;
 
-        const dim3 num_blocks(ne02, 1, 1);
-        const dim3 block_size(warp_size, 1, 1);
-        const size_t nbytes_shared = 2*ne12*sizeof(int);
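+        // Dispatch on n_expert_used so that common expert counts get a kernel
+        // specialized at compile time; <0> falls back to the runtime value.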
         switch (n_expert_used) {
             case  2:
-                mmq_ids_helper< 2><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 2>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case  4:
-                mmq_ids_helper< 4><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 4>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case  6:
-                mmq_ids_helper< 6><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 6>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case  8:
-                mmq_ids_helper< 8><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 8>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 16:
-                mmq_ids_helper<16><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper<16>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             case 32:
-                mmq_ids_helper<32><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper<32>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
             default:
-                mmq_ids_helper< 0><<<num_blocks, block_size, nbytes_shared, stream>>>
-                    ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne12, n_expert_used, ne11, si1, sis1);
+                launch_mmq_ids_helper< 0>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                 break;
         }
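+        // Kernel launches are asynchronous; surface any launch error right away.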
+        CUDA_CHECK(cudaGetLastError());
     }
 
     const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +