 // SPDX-License-Identifier: MIT
 //
 #include "argsort.cuh"
+#include "sumrows.cuh"
 
 template <typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {
@@ -24,8 +25,8 @@ struct store {
     constexpr static bool has_thresh = false;
 };
 
-template <ggml_sort_order order, typename Store>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) {
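+// dst_t selects what gets stored: int writes the sorted indices (argsort), a floating point
+// dst_t writes the sorted values; only the first ntop entries of each row are written
+// (entries filtered out by the expert threshold become -1 / 0.f respectively).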
+template <ggml_sort_order order, typename Store, typename dst_t>
+static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) {
 //                                     int min_experts, float thresh_experts) {
     // bitonic sort
     int col = threadIdx.x;
@@ -72,27 +73,99 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
     if constexpr (Store::has_thresh) {
         __syncthreads();
         float max_val = x_row[dst_row[0]];
-        if (col < ncols) {
-            dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+        if (col < ntop) {
+            if constexpr (std::is_same_v<dst_t, int>) {
+                dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+            } else {
+                dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? x_row[dst_row[col]] : 0.f;
+            }
         }
     } else {
-        if (col < ncols) {
-            dst[row * ncols + col] = dst_row[col];
+        if (col < ntop) {
+            if constexpr (std::is_same_v<dst_t, int>) {
+                dst[row * ntop + col] = dst_row[col];
+            } else {
+                dst[row * ntop + col] = x_row[dst_row[col]];
+            }
+        }
+    }
+}
+
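+// Per-row top-k sum: each block bitonic-sorts one row's indices in shared memory (per the
+// `order` template argument) and reduces the values of the first n_top_k sorted entries
+// to a single float per row.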
+template <ggml_sort_order order>
+static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const float * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
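+    // sum the values of the n_top_k best entries: first within each warp, then across warps
+    // via shared memory (dst_row is reused as float scratch, hence the extra __syncthreads)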
+    float val = col < n_top_k ? x_row[dst_row[col]] : 0;
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        __syncthreads();
+        float * s_sum = (float *)dst_row;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = val;
         }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = s_sum[lane_id];
+        }
+        val = warp_reduce_sum(val);
+    }
+
+    if (col == 0) {
+        dst[row] = val;
+    }
+}
+
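+// Sets the scores of every column belonging to one of the groups listed in `groups`
+// (n_top_groups of them per row) to -INFINITY, excluding them from later selection.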
+static __global__ void k_apply_mask(float * dst, const int * groups,
+        const int n_top_groups, const int n_per_group, const int ncols) {
+    int row = blockIdx.y;
+    for (int col = threadIdx.x; col < n_top_groups*n_per_group; col += blockDim.x) {
+        int ig = groups[row*n_top_groups + col / n_per_group];
+        int ic = col % n_per_group;
+        dst[row*ncols + ig*n_per_group + ic] = -INFINITY;
     }
-    // if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-    //     __syncthreads();
-    //     float max_val = x_row[dst_row[0]];
-    //     if (col < ncols) {
-    //         dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
-    //     }
-    // }
-    // else {
-    //     // copy the result to dst without the padding
-    //     if (col < ncols) {
-    //         dst[row * ncols + col] = dst_row[col];
-    //     }
-    // }
 }
 
 static int next_power_of_2(int x) {
@@ -103,7 +176,8 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows,
+template <typename dst_t>
+static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, const int nrows, int ntop,
                                  ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
@@ -117,20 +191,18 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
 
     if (order == GGML_SORT_ORDER_ASC) {
         if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+            k_argsort_f32_T<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
+                    ntop, {min_experts, thresh_experts});
         } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_T<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
         }
-        // k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
     } else if (order == GGML_SORT_ORDER_DESC) {
         if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+            k_argsort_f32_T<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
+                    ntop, {min_experts, thresh_experts});
         } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_T<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
         }
-        // k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -151,7 +223,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream);
+    argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream);
 }
 
 void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -171,5 +243,70 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
     float thresh;
     memcpy(&thresh, dst->op_params + 1, sizeof(float));
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+    argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+}
+
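+// Launches k_topk_sum with one block of ncols_pad threads per row; the shared memory buffer
+// holds the sort indices and doubles as reduction scratch (hence at least WARP_SIZE ints).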
+static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
+
+    GGML_ASSERT(n_top_k <= ncols);
+
+    const int ncols_pad = next_power_of_2(ncols);
+
+    const dim3 block_dims(ncols_pad, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int);
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+    k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
+}
+
+void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    auto src = dst->src[0];
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
+
+    auto nrows = ggml_nrows(src);
+
+    int n_groups = dst->op_params[0];
+    int n_top_groups = dst->op_params[1];
+    int nk = dst->op_params[2];
+
+    int ne00 = src->ne[0];
+    int ne0 = dst->ne[0];
+    GGML_ASSERT(ne0 <= ne00);
+    GGML_ASSERT(ne00%n_groups == 0);
+    int n_per_group = ne00/n_groups;
+    GGML_ASSERT(nk <= n_per_group);
+    GGML_ASSERT(n_top_groups < n_groups);
+    int n_discarded_groups = n_groups - n_top_groups;
+
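+    // per-group score = sum of the nk largest values inside the group; the disabled branch
+    // computes it via a top-nk argsort plus a row sum, the active one via the fused k_topk_sum kernel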
+#if 0
+    ggml_cuda_pool_alloc<float> sorted_group_scores(ctx.pool(), nk*nrows*n_groups);
+    argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk,
+            GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+#else
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
+    CUDA_CHECK(cudaGetLastError());
+#endif
+
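+    // ascending argsort of the group scores: the first n_discarded_groups indices per row
+    // are the lowest-scoring groups, i.e. the ones to exclude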
+    ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
+    argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+
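+    // mask the experts of the discarded groups in place: their scores in src->data are set to -INFINITY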
+    {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_nums(1, nrows, 1);
+        cudaStream_t stream = ctx.stream();
+        k_apply_mask<<<block_nums, block_dims, 0, ctx.stream()>>>((float *)src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
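+    // final selection: argsort the masked scores of the full row and keep the ne0 best indices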
+    argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
+
 }