// SPDX-License-Identifier: MIT
//
#include "argsort.cuh"
+#include "sumrows.cuh"

template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
@@ -24,8 +25,8 @@ struct store {
    constexpr static bool has_thresh = false;
};

-template<ggml_sort_order order, typename Store>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) {
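+// The kernel is now templated on the destination type: dst_t = int stores the
+// sorted indices (argsort), any other dst_t stores the sorted values (top-k).
+// Only the first ntop entries of each row are written to dst.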
+template<ggml_sort_order order, typename Store, typename dst_t>
+static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) {
//         int min_experts, float thresh_experts) {
    // bitonic sort
    int col = threadIdx.x;
@@ -72,27 +73,99 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
    if constexpr (Store::has_thresh) {
        __syncthreads();
        float max_val = x_row[dst_row[0]];
-        if (col < ncols) {
-            dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+        if (col < ntop) {
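+            // thresholded store: entries below thresh_experts*max_val are filtered
+            // out (-1 for an index dst, 0.f for a value dst); the first min_experts
+            // columns are always kept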
+            if constexpr (std::is_same_v<dst_t, int>) {
+                dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+            } else {
+                dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? x_row[dst_row[col]] : 0.f;
+            }
        }
    } else {
-        if (col < ncols) {
-            dst[row * ncols + col] = dst_row[col];
+        if (col < ntop) {
+            if constexpr (std::is_same_v<dst_t, int>) {
+                dst[row * ntop + col] = dst_row[col];
+            } else {
+                dst[row * ntop + col] = x_row[dst_row[col]];
+            }
+        }
+    }
+}
+
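+// Fused per-row top-k sum: bitonic-sorts the indices of one row in shared memory,
+// then reduces the values of the leading n_top_k entries to a single sum per row.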
+template<ggml_sort_order order>
+static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const float * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
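+    // compare-and-swap stages of the bitonic network; padding indices (>= ncols)
+    // lose every comparison and therefore sort to the tail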
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
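+    // reduce the top n_top_k values: shuffle-based sum within each warp, then a
+    // cross-warp pass that reuses the dst_row shared memory as float scratch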
+    float val = col < n_top_k ? x_row[dst_row[col]] : 0;
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        __syncthreads();
+        float * s_sum = (float *)dst_row;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = val;
        }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = s_sum[lane_id];
+        }
+        val = warp_reduce_sum(val);
+    }
+
+    if (col == 0) {
+        dst[row] = val;
+    }
+}
+
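+// Masks whole groups in place: for each row, every entry of the group ids listed
+// in `groups` is overwritten with -INFINITY. As called below, `groups` contains
+// the discarded (lowest-scoring) groups.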
+static __global__ void k_apply_mask(float * dst, const int * groups,
+        const int n_top_groups, const int n_per_group, const int ncols) {
+    int row = blockIdx.y;
+    for (int col = threadIdx.x; col < n_top_groups*n_per_group; col += blockDim.x) {
+        int ig = groups[row*n_top_groups + col / n_per_group];
+        int ic = col % n_per_group;
+        dst[row*ncols + ig*n_per_group + ic] = -INFINITY;
    }
-    // if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-    //     __syncthreads();
-    //     float max_val = x_row[dst_row[0]];
-    //     if (col < ncols) {
-    //         dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
-    //     }
-    // }
-    // else {
-    //     // copy the result to dst without the padding
-    //     if (col < ncols) {
-    //         dst[row * ncols + col] = dst_row[col];
-    //     }
-    // }
}

static int next_power_of_2(int x) {
@@ -103,7 +176,8 @@ static int next_power_of_2(int x) {
    return n;
}

-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows,
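+// Host launcher, templated on the destination type; ntop <= ncols selects how
+// many sorted entries per row are written.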
+template <typename dst_t>
+static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, const int nrows, int ntop,
        ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);
@@ -117,20 +191,18 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co

    if (order == GGML_SORT_ORDER_ASC) {
        if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+            k_argsort_f32_T<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
+                    ntop, {min_experts, thresh_experts});
        } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_T<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
        }
-        // k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
    } else if (order == GGML_SORT_ORDER_DESC) {
        if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+            k_argsort_f32_T<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
+                    ntop, {min_experts, thresh_experts});
        } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_T<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
        }
-        // k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
    } else {
        GGML_ABORT("fatal error");
    }
@@ -151,7 +223,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream);
+    argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream);
}

void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -171,5 +243,70 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
    float thresh;
    memcpy(&thresh, dst->op_params + 1, sizeof(float));

-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+    argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+}
+
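+// Host wrapper for k_topk_sum: one block per row, one thread per padded column.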
+static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
+
+    GGML_ASSERT(n_top_k <= ncols);
+
+    const int ncols_pad = next_power_of_2(ncols);
+
+    const dim3 block_dims(ncols_pad, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
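+    // at least WARP_SIZE entries so that the float scratch reused by the
+    // cross-warp reduction always fits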
+    const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int);
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+    k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
+}
+
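+// Grouped top-k: score each group of n_per_group columns by the sum of its nk
+// largest values, discard the n_groups - n_top_groups lowest-scoring groups by
+// masking them with -INFINITY, then argsort the masked rows to obtain the final
+// top-ne0 indices.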
+void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    auto src = dst->src[0];
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
+
+    auto nrows = ggml_nrows(src);
+
+    int n_groups     = dst->op_params[0];
+    int n_top_groups = dst->op_params[1];
+    int nk           = dst->op_params[2];
+
+    int ne00 = src->ne[0];
+    int ne0  = dst->ne[0];
+    GGML_ASSERT(ne0 <= ne00);
+    GGML_ASSERT(ne00%n_groups == 0);
+    int n_per_group = ne00/n_groups;
+    GGML_ASSERT(nk <= n_per_group);
+    GGML_ASSERT(n_top_groups < n_groups);
+    int n_discarded_groups = n_groups - n_top_groups;
+
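+    // two equivalent ways to compute the group scores: the disabled path sorts and
+    // then sums the top nk values in two kernels, the enabled path fuses both
+    // steps into a single k_topk_sum launch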
+#if 0
+    ggml_cuda_pool_alloc<float> sorted_group_scores(ctx.pool(), nk*nrows*n_groups);
+    argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk,
+            GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+#else
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
+    CUDA_CHECK(cudaGetLastError());
+#endif
+
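+    // ascending argsort of the group scores: the first n_discarded_groups indices
+    // per row are the lowest-scoring groups, i.e. the ones to mask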
+    ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
+    argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+
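+    // overwrite the discarded groups with -INFINITY, in place in the source tensor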
+    {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_nums(1, nrows, 1);
+        cudaStream_t stream = ctx.stream();
+        k_apply_mask<<<block_nums, block_dims, 0, stream>>>((float *)src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
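+    // masked entries sort last in descending order, so the top ne0 indices come
+    // from the surviving groups only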
+    argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
+
}