Skip to content

Commit 98de698

Browse files
committed
Argsort: move to a separate file
1 parent c9f419f commit 98de698

File tree

4 files changed

+129
-129
lines changed

4 files changed

+129
-129
lines changed

ggml/src/ggml-sycl/argsort.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include "argsort.hpp"
2+
3+
template <typename T>
4+
static inline void ggml_sycl_swap(T & a, T & b) {
5+
T tmp = a;
6+
a = b;
7+
b = tmp;
8+
}
9+
10+
template <ggml_sort_order order>
11+
__dpct_inline__ static void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad,
12+
const sycl::nd_item<3> & item_ct1, uint8_t * dpct_local) {
13+
// bitonic sort
14+
int col = item_ct1.get_local_id(2);
15+
int row = item_ct1.get_group(1);
16+
17+
if (col >= ncols_pad) {
18+
return;
19+
}
20+
21+
const float * x_row = x + row * ncols;
22+
auto dst_row = (int *) dpct_local;
23+
24+
// initialize indices
25+
dst_row[col] = col;
26+
27+
item_ct1.barrier(sycl::access::fence_space::local_space);
28+
29+
for (int k = 2; k <= ncols_pad; k *= 2) {
30+
for (int j = k / 2; j > 0; j /= 2) {
31+
int ixj = col ^ j;
32+
if (ixj > col) {
33+
if ((col & k) == 0) {
34+
if (dst_row[col] >= ncols ||
35+
(dst_row[ixj] < ncols &&
36+
(order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] :
37+
x_row[dst_row[col]] < x_row[dst_row[ixj]]))) {
38+
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
39+
}
40+
} else {
41+
if (dst_row[ixj] >= ncols ||
42+
(dst_row[col] < ncols &&
43+
(order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] :
44+
x_row[dst_row[col]] > x_row[dst_row[ixj]]))) {
45+
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
46+
}
47+
}
48+
}
49+
/*
50+
DPCT1118:1: SYCL group functions and algorithms must be encountered
51+
in converged control flow. You may need to adjust the code.
52+
*/
53+
item_ct1.barrier(sycl::access::fence_space::local_space);
54+
}
55+
}
56+
57+
// copy the result to dst without the padding
58+
if (col < ncols) {
59+
dst[row * ncols + col] = dst_row[col];
60+
}
61+
}
62+
63+
static void argsort_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order,
64+
queue_ptr stream) {
65+
// bitonic sort requires ncols to be power of 2
66+
const int ncols_pad = next_power_of_2(ncols);
67+
68+
const sycl::range<3> block_dims(1, 1, ncols_pad);
69+
const sycl::range<3> block_nums(1, nrows, 1);
70+
const size_t shared_mem = ncols_pad * sizeof(int);
71+
72+
if (order == GGML_SORT_ORDER_ASC) {
73+
stream->submit([&](sycl::handler & cgh) {
74+
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh);
75+
76+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
77+
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
78+
x, dst, ncols, ncols_pad, item_ct1,
79+
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
80+
});
81+
});
82+
} else if (order == GGML_SORT_ORDER_DESC) {
83+
stream->submit([&](sycl::handler & cgh) {
84+
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh);
85+
86+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
87+
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
88+
x, dst, ncols, ncols_pad, item_ct1,
89+
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
90+
});
91+
});
92+
} else {
93+
GGML_ABORT("fatal error");
94+
}
95+
}
96+
97+
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
98+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
99+
GGML_ASSERT(dst->type == GGML_TYPE_I32);
100+
101+
const int64_t ncols = dst->src[0]->ne[0];
102+
const int64_t nrows = ggml_nrows(dst->src[0]);
103+
104+
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
105+
dpct::queue_ptr main_stream = ctx.stream();
106+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
107+
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
108+
109+
argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream);
110+
} catch (const sycl::exception & exc) {
111+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
112+
std::exit(1);
113+
}
114+
115+
void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
116+
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
117+
GGML_SYCL_DEBUG("call %s\n", __func__);
118+
ggml_sycl_op_argsort(ctx, dst);
119+
GGML_SYCL_DEBUG("call %s done\n", __func__);
120+
}

ggml/src/ggml-sycl/argsort.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_ARGSORT_HPP
2+
#define GGML_SYCL_ARGSORT_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_ARGSORT_HPP

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "element_wise.hpp"
3232
#include "binbcast.hpp"
3333
#include "argmax.hpp"
34+
#include "argsort.hpp"
3435
#include "gla.hpp"
3536

3637
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,70 +1730,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
17301730
}
17311731

17321732

1733-
template<typename T>
1734-
static inline void ggml_sycl_swap(T & a, T & b) {
1735-
T tmp = a;
1736-
a = b;
1737-
b = tmp;
1738-
}
1739-
1740-
template <ggml_sort_order order>
1741-
__dpct_inline__ static void
1742-
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
1743-
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
1744-
// bitonic sort
1745-
int col = item_ct1.get_local_id(2);
1746-
int row = item_ct1.get_group(1);
1747-
1748-
if (col >= ncols_pad) {
1749-
return;
1750-
}
1751-
1752-
const float * x_row = x + row * ncols;
1753-
auto dst_row = (int *)dpct_local;
1754-
1755-
// initialize indices
1756-
dst_row[col] = col;
1757-
1758-
item_ct1.barrier(sycl::access::fence_space::local_space);
1759-
1760-
for (int k = 2; k <= ncols_pad; k *= 2) {
1761-
for (int j = k / 2; j > 0; j /= 2) {
1762-
int ixj = col ^ j;
1763-
if (ixj > col) {
1764-
if ((col & k) == 0) {
1765-
if (dst_row[col] >= ncols ||
1766-
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
1767-
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
1768-
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
1769-
) {
1770-
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1771-
}
1772-
} else {
1773-
if (dst_row[ixj] >= ncols ||
1774-
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
1775-
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
1776-
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
1777-
) {
1778-
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
1779-
}
1780-
}
1781-
}
1782-
/*
1783-
DPCT1118:1: SYCL group functions and algorithms must be encountered
1784-
in converged control flow. You may need to adjust the code.
1785-
*/
1786-
item_ct1.barrier(sycl::access::fence_space::local_space);
1787-
}
1788-
}
1789-
1790-
// copy the result to dst without the padding
1791-
if (col < ncols) {
1792-
dst[row * ncols + col] = dst_row[col];
1793-
}
1794-
}
1795-
1796-
17971733
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
17981734
const sycl::nd_item<3> &item_ct1) {
17991735
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -2304,49 +2240,6 @@ static int next_power_of_2(int x) {
23042240
return n;
23052241
}
23062242

2307-
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
2308-
const int nrows, ggml_sort_order order,
2309-
queue_ptr stream) {
2310-
// bitonic sort requires ncols to be power of 2
2311-
const int ncols_pad = next_power_of_2(ncols);
2312-
2313-
const sycl::range<3> block_dims(1, 1, ncols_pad);
2314-
const sycl::range<3> block_nums(1, nrows, 1);
2315-
const size_t shared_mem = ncols_pad * sizeof(int);
2316-
2317-
if (order == GGML_SORT_ORDER_ASC) {
2318-
stream->submit([&](sycl::handler &cgh) {
2319-
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
2320-
sycl::range<1>(shared_mem), cgh);
2321-
2322-
cgh.parallel_for(
2323-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2324-
[=](sycl::nd_item<3> item_ct1) {
2325-
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
2326-
x, dst, ncols, ncols_pad, item_ct1,
2327-
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
2328-
.get());
2329-
});
2330-
});
2331-
} else if (order == GGML_SORT_ORDER_DESC) {
2332-
stream->submit([&](sycl::handler &cgh) {
2333-
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
2334-
sycl::range<1>(shared_mem), cgh);
2335-
2336-
cgh.parallel_for(
2337-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2338-
[=](sycl::nd_item<3> item_ct1) {
2339-
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
2340-
x, dst, ncols, ncols_pad, item_ct1,
2341-
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
2342-
.get());
2343-
});
2344-
});
2345-
} else {
2346-
GGML_ABORT("fatal error");
2347-
}
2348-
}
2349-
23502243
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
23512244
const int ncols_x, const int nrows_x,
23522245
const int rows_per_channel, const int n_past,
@@ -2678,22 +2571,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
26782571
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
26792572
}
26802573

2681-
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2682-
2683-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2684-
GGML_ASSERT(dst->type == GGML_TYPE_I32);
2685-
2686-
const int64_t ncols = dst->src[0]->ne[0];
2687-
const int64_t nrows = ggml_nrows(dst->src[0]);
2688-
2689-
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2690-
dpct::queue_ptr main_stream = ctx.stream();
2691-
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2692-
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2693-
2694-
argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream);
2695-
}
2696-
26972574
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
26982575

26992576
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
@@ -3758,12 +3635,6 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
37583635
ggml_sycl_op_sum_rows(ctx, dst);
37593636
}
37603637

3761-
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3762-
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3763-
ggml_sycl_op_argsort(ctx, dst);
3764-
}
3765-
3766-
37673638
void ggml_sycl_set_main_device(const int main_device) try {
37683639
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
37693640
return;

0 commit comments

Comments
 (0)