Commit fa54b5b

Merge branch 'ikawrakow:main' into main
2 parents a2dd06c + 747f411

File tree

ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/argsort.cu
ggml/src/ggml-cuda/argsort.cuh
ggml/src/ggml-cuda/sumrows.cu
ggml/src/ggml-cuda/sumrows.cuh

5 files changed: +176 -31 lines changed

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions

@@ -3323,6 +3323,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT_THRESH:
             ggml_cuda_op_argsort_thresh(ctx, dst);
             break;
+        case GGML_OP_GROUPED_TOPK:
+            ggml_cuda_op_grouped_topk(ctx, dst);
+            break;
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;

@@ -4332,6 +4335,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ARGSORT_THRESH:
+        case GGML_OP_GROUPED_TOPK:
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
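These two hunks wire the new operator into the CUDA backend: the first dispatches GGML_OP_GROUPED_TOPK to ggml_cuda_op_grouped_topk in the forward-compute switch, and the second adds the op to the supports_op whitelist so the scheduler can place it on the GPU.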

ggml/src/ggml-cuda/argsort.cu

Lines changed: 167 additions & 30 deletions

@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include "argsort.cuh"
+#include "sumrows.cuh"
 
 template<typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {

@@ -24,8 +25,8 @@ struct store {
     constexpr static bool has_thresh = false;
 };
 
-template<ggml_sort_order order, typename Store>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) {
+template<ggml_sort_order order, typename Store, typename dst_t>
+static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) {
     // int min_experts, float thresh_experts) {
     // bitonic sort
     int col = threadIdx.x;
@@ -72,27 +73,99 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
     if constexpr (Store::has_thresh) {
         __syncthreads();
         float max_val = x_row[dst_row[0]];
-        if (col < ncols) {
-            dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+        if (col < ntop) {
+            if constexpr (std::is_same_v<dst_t, int>) {
+                dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+            } else {
+                dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? x_row[dst_row[col]] : 0.f;
+            }
         }
     } else {
-        if (col < ncols) {
-            dst[row * ncols + col] = dst_row[col];
+        if (col < ntop) {
+            if constexpr (std::is_same_v<dst_t, int>) {
+                dst[row * ntop + col] = dst_row[col];
+            } else {
+                dst[row * ntop + col] = x_row[dst_row[col]];
+            }
+        }
+    }
+}
+
+template<ggml_sort_order order>
+static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const float * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    float val = col < n_top_k ? x_row[dst_row[col]] : 0;
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        __syncthreads();
+        float * s_sum = (float *)dst_row;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = val;
         }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = s_sum[lane_id];
+        }
+        val = warp_reduce_sum(val);
+    }
+
+    if (col == 0) {
+        dst[row] = val;
+    }
+}
+
+static __global__ void k_apply_mask(float * dst, const int * groups,
+        const int n_top_groups, const int n_per_group, const int ncols) {
+    int row = blockIdx.y;
+    for (int col = threadIdx.x; col < n_top_groups*n_per_group; col += blockDim.x) {
+        int ig = groups[row*n_top_groups + col / n_per_group];
+        int ic = col % n_per_group;
+        dst[row*ncols + ig*n_per_group + ic] = -INFINITY;
     }
-    //if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-    //    __syncthreads();
-    //    float max_val = x_row[dst_row[0]];
-    //    if (col < ncols) {
-    //        dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
-    //    }
-    //}
-    //else {
-    //    // copy the result to dst without the padding
-    //    if (col < ncols) {
-    //        dst[row * ncols + col] = dst_row[col];
-    //    }
-    //}
 }
 
 static int next_power_of_2(int x) {
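A note on the new k_topk_sum kernel above: it reuses the bitonic-sort scaffolding of k_argsort_f32_T to order one row's indices in descending order, then sums the n_top_k largest values with a warp reduction plus a cross-warp pass that reuses the shared dst_row buffer as float scratch; this is why the launcher below allocates max(ncols_pad, WARP_SIZE) ints of shared memory. warp_reduce_sum itself comes from common.cuh; as a rough sketch of the operation it performs (an illustration of the usual shuffle-based butterfly reduction, not the verbatim ggml definition):

```cuda
// Sketch of a shuffle-based warp-wide sum, the primitive k_topk_sum relies on.
// ggml's warp_reduce_sum (common.cuh) is equivalent in spirit; this is an
// illustration, not the verbatim source.
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    // butterfly reduction: after log2(WARP_SIZE) steps, every lane of the
    // warp holds the sum of all lanes' inputs
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x;
}
```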
@@ -103,7 +176,8 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows,
+template <typename dst_t>
+static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, const int nrows, int ntop,
         ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
@@ -117,20 +191,18 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
 
     if (order == GGML_SORT_ORDER_ASC) {
         if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+            k_argsort_f32_T<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
+                    ntop, {min_experts, thresh_experts});
         } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_T<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
         }
-        //k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
     } else if (order == GGML_SORT_ORDER_DESC) {
         if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+            k_argsort_f32_T<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
+                    ntop, {min_experts, thresh_experts});
         } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_T<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
         }
-        //k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -151,7 +223,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream);
+    argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream);
 }
 
 void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -171,5 +243,70 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
     float thresh;
     memcpy(&thresh, dst->op_params + 1, sizeof(float));
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+    argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+}
+
+static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
+
+    GGML_ASSERT(n_top_k <= ncols);
+
+    const int ncols_pad = next_power_of_2(ncols);
+
+    const dim3 block_dims(ncols_pad, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int);
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+    k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
+}
+
+void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    auto src = dst->src[0];
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
+
+    auto nrows = ggml_nrows(src);
+
+    int n_groups = dst->op_params[0];
+    int n_top_groups = dst->op_params[1];
+    int nk = dst->op_params[2];
+
+    int ne00 = src->ne[0];
+    int ne0 = dst->ne[0];
+    GGML_ASSERT(ne0 <= ne00);
+    GGML_ASSERT(ne00%n_groups == 0);
+    int n_per_group = ne00/n_groups;
+    GGML_ASSERT(nk <= n_per_group);
+    GGML_ASSERT(n_top_groups < n_groups);
+    int n_discarded_groups = n_groups - n_top_groups;
+
+#if 0
+    ggml_cuda_pool_alloc<float> sorted_group_scores(ctx.pool(), nk*nrows*n_groups);
+    argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk,
+            GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+#else
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
+    CUDA_CHECK(cudaGetLastError());
+#endif
+
+    ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
+    argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+
+    {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_nums(1, nrows, 1);
+        cudaStream_t stream = ctx.stream();
+        k_apply_mask<<<block_nums, block_dims, 0, ctx.stream()>>>((float *)src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
 }
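Taken together, ggml_cuda_op_grouped_topk runs three passes over each row of scores: (1) score every group of n_per_group entries by the sum of its nk largest values (the fused k_topk_sum path; the #if 0 branch keeps an equivalent two-kernel argsort + sum_rows variant), (2) find the n_groups - n_top_groups lowest-scoring groups via an ascending argsort and mask their entries to -INFINITY in place with k_apply_mask (note that its parameter named n_top_groups actually receives the discarded-group count), and (3) argsort the masked row in descending order to emit the top ne0 indices. A minimal host-side reference of the same semantics (illustration only; the function name, signature, and use of std algorithms are assumptions, not part of the diff):

```cuda
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical CPU reference for one row of scores; mirrors the GPU pipeline:
// group scoring by top-nk sum, masking of discarded groups, final top-ne0 argsort.
static std::vector<int> grouped_topk_ref(std::vector<float> row, // taken by value: it gets masked
        int n_groups, int n_top_groups, int nk, int ne0) {
    const int ne00        = (int)row.size();
    const int n_per_group = ne00 / n_groups;

    // (1) score each group by the sum of its nk largest entries
    std::vector<std::pair<float, int>> group_scores(n_groups);
    for (int g = 0; g < n_groups; ++g) {
        std::vector<float> grp(row.begin() + g*n_per_group, row.begin() + (g + 1)*n_per_group);
        std::partial_sort(grp.begin(), grp.begin() + nk, grp.end(), std::greater<float>());
        group_scores[g] = { std::accumulate(grp.begin(), grp.begin() + nk, 0.0f), g };
    }

    // (2) mask out every entry of the n_groups - n_top_groups worst groups
    std::sort(group_scores.begin(), group_scores.end()); // ascending: worst groups first
    for (int i = 0; i < n_groups - n_top_groups; ++i) {
        const int g = group_scores[i].second;
        std::fill_n(row.begin() + g*n_per_group, n_per_group, -INFINITY);
    }

    // (3) indices of the ne0 largest surviving entries, best first
    std::vector<int> idx(ne00);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + ne0, idx.end(),
            [&row](int a, int b) { return row[a] > row[b]; });
    idx.resize(ne0);
    return idx;
}
```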

ggml/src/ggml-cuda/argsort.cuh

Lines changed: 2 additions & 0 deletions

@@ -9,3 +9,5 @@
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
     }
 }
 
-static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
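The only functional change here is dropping static from sum_rows_f32_cuda so the launcher is visible outside sumrows.cu: the alternative (#if 0) path of ggml_cuda_op_grouped_topk in argsort.cu calls it, and the new declaration in sumrows.cuh below exposes it.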

ggml/src/ggml-cuda/sumrows.cuh

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 #include "common.cuh"
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
