@@ -252,7 +252,7 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x
 
 static __global__ void k_apply_mask(float * dst, const int * groups,
         const int n_top_groups, const int n_per_group, const int ncols) {
-    int row = blockIdx.y;
+    int row = blockIdx.x;
     for (int col = threadIdx.x; col < n_top_groups*n_per_group; col += blockDim.x) {
        int ig = groups[row*n_top_groups + col / n_per_group];
        int ic = col % n_per_group;
@@ -463,7 +463,7 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
 
     {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        const dim3 block_nums(1, nrows, 1);
+        const dim3 block_nums(nrows, 1, 1);
         cudaStream_t stream = ctx.stream();
         k_apply_mask<<<block_nums, block_dims, 0, ctx.stream()>>>((float *)src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
         CUDA_CHECK(cudaGetLastError());
@@ -508,7 +508,7 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds
 
     {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        const dim3 block_nums(1, nrows, 1);
+        const dim3 block_nums(nrows, 1, 1);
         k_apply_mask<<<block_nums, block_dims, 0, ctx.stream()>>>((float *)topk_src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
         CUDA_CHECK(cudaGetLastError());
     }
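
All three hunks apply the same fix: rows are now indexed on the grid's x dimension instead of y, both in the kernel (`blockIdx.y` → `blockIdx.x`) and at the launch sites (`dim3(1, nrows, 1)` → `dim3(nrows, 1, 1)`). The likely motivation is CUDA's grid-size limits: `gridDim.y` and `gridDim.z` are capped at 65535 blocks, while `gridDim.x` allows up to 2^31-1, so a launch with `nrows > 65535` on the y axis would fail with an invalid-configuration error. Below is a minimal standalone sketch of the corrected one-block-per-row launch pattern; the kernel and variable names are hypothetical stand-ins, not the ggml code.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for k_apply_mask: one block per row,
// one warp per block, threads striding across the row's columns.
__global__ void k_per_row(float * dst, const int ncols) {
    const int row = blockIdx.x;                   // was blockIdx.y before the fix
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[(size_t)row*ncols + col] = 0.0f;      // placeholder per-row work
    }
}

int main() {
    const int nrows = 100000;  // > 65535: would exceed the gridDim.y limit
    const int ncols = 64;

    float * dst = nullptr;
    cudaMalloc(&dst, sizeof(float)*(size_t)nrows*ncols);

    const dim3 block_dims(32, 1, 1);      // WARP_SIZE threads, as in the patch
    const dim3 block_nums(nrows, 1, 1);   // rows on x: up to 2^31-1 blocks
    k_per_row<<<block_nums, block_dims>>>(dst, ncols);
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaDeviceSynchronize();
    cudaFree(dst);
    return 0;
}
```

Note that only the axis used for row indexing changes; keeping `blockDim.x` at `WARP_SIZE` and striding columns by `blockDim.x` means the per-row loop body is untouched by the fix.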