@@ -1730,70 +1730,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
17301730}
17311731
17321732
1733- template <typename T>
1734- static inline void ggml_sycl_swap (T & a, T & b) {
1735- T tmp = a;
1736- a = b;
1737- b = tmp;
1738- }
1739-
1740- template <ggml_sort_order order>
1741- __dpct_inline__ static void
1742- k_argsort_f32_i32 (const float *x, int *dst, const int ncols, int ncols_pad,
1743- const sycl::nd_item<3 > &item_ct1, uint8_t *dpct_local) {
1744- // bitonic sort
1745- int col = item_ct1.get_local_id (2 );
1746- int row = item_ct1.get_group (1 );
1747-
1748- if (col >= ncols_pad) {
1749- return ;
1750- }
1751-
1752- const float * x_row = x + row * ncols;
1753- auto dst_row = (int *)dpct_local;
1754-
1755- // initialize indices
1756- dst_row[col] = col;
1757-
1758- item_ct1.barrier (sycl::access::fence_space::local_space);
1759-
1760- for (int k = 2 ; k <= ncols_pad; k *= 2 ) {
1761- for (int j = k / 2 ; j > 0 ; j /= 2 ) {
1762- int ixj = col ^ j;
1763- if (ixj > col) {
1764- if ((col & k) == 0 ) {
1765- if (dst_row[col] >= ncols ||
1766- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
1767- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
1768- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
1769- ) {
1770- ggml_sycl_swap (dst_row[col], dst_row[ixj]);
1771- }
1772- } else {
1773- if (dst_row[ixj] >= ncols ||
1774- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
1775- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
1776- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
1777- ) {
1778- ggml_sycl_swap (dst_row[col], dst_row[ixj]);
1779- }
1780- }
1781- }
1782- /*
1783- DPCT1118:1: SYCL group functions and algorithms must be encountered
1784- in converged control flow. You may need to adjust the code.
1785- */
1786- item_ct1.barrier (sycl::access::fence_space::local_space);
1787- }
1788- }
1789-
1790- // copy the result to dst without the padding
1791- if (col < ncols) {
1792- dst[row * ncols + col] = dst_row[col];
1793- }
1794- }
1795-
1796-
17971733static void diag_mask_inf_f32 (const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
17981734 const sycl::nd_item<3 > &item_ct1) {
17991735 const int col = item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
@@ -2304,49 +2240,6 @@ static int next_power_of_2(int x) {
23042240 return n;
23052241}
23062242
2307- static void argsort_f32_i32_sycl (const float *x, int *dst, const int ncols,
2308- const int nrows, ggml_sort_order order,
2309- queue_ptr stream) {
2310- // bitonic sort requires ncols to be power of 2
2311- const int ncols_pad = next_power_of_2 (ncols);
2312-
2313- const sycl::range<3 > block_dims (1 , 1 , ncols_pad);
2314- const sycl::range<3 > block_nums (1 , nrows, 1 );
2315- const size_t shared_mem = ncols_pad * sizeof (int );
2316-
2317- if (order == GGML_SORT_ORDER_ASC) {
2318- stream->submit ([&](sycl::handler &cgh) {
2319- sycl::local_accessor<uint8_t , 1 > dpct_local_acc_ct1 (
2320- sycl::range<1 >(shared_mem), cgh);
2321-
2322- cgh.parallel_for (
2323- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
2324- [=](sycl::nd_item<3 > item_ct1) {
2325- k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
2326- x, dst, ncols, ncols_pad, item_ct1,
2327- dpct_local_acc_ct1.get_multi_ptr <sycl::access::decorated::no>()
2328- .get ());
2329- });
2330- });
2331- } else if (order == GGML_SORT_ORDER_DESC) {
2332- stream->submit ([&](sycl::handler &cgh) {
2333- sycl::local_accessor<uint8_t , 1 > dpct_local_acc_ct1 (
2334- sycl::range<1 >(shared_mem), cgh);
2335-
2336- cgh.parallel_for (
2337- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
2338- [=](sycl::nd_item<3 > item_ct1) {
2339- k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
2340- x, dst, ncols, ncols_pad, item_ct1,
2341- dpct_local_acc_ct1.get_multi_ptr <sycl::access::decorated::no>()
2342- .get ());
2343- });
2344- });
2345- } else {
2346- GGML_ABORT (" fatal error" );
2347- }
2348- }
2349-
23502243static void diag_mask_inf_f32_sycl (const float *x, float *dst,
23512244 const int ncols_x, const int nrows_x,
23522245 const int rows_per_channel, const int n_past,
@@ -2678,22 +2571,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
26782571 sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
26792572}
26802573
2681- inline void ggml_sycl_op_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2682-
2683- GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
2684- GGML_ASSERT (dst->type == GGML_TYPE_I32);
2685-
2686- const int64_t ncols = dst->src [0 ]->ne [0 ];
2687- const int64_t nrows = ggml_nrows (dst->src [0 ]);
2688-
2689- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params [0 ];
2690- dpct::queue_ptr main_stream = ctx.stream ();
2691- const float * src0_dd = static_cast <const float *>(dst->src [0 ]->data );
2692- int32_t * dst_dd = static_cast <int32_t *>(dst->data );
2693-
2694- argsort_f32_i32_sycl (src0_dd, dst_dd, ncols, nrows, order, main_stream);
2695- }
2696-
26972574inline void ggml_sycl_op_diag_mask_inf (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
26982575
26992576 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
@@ -3758,12 +3635,6 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
37583635 ggml_sycl_op_sum_rows (ctx, dst);
37593636}
37603637
3761- static void ggml_sycl_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3762- GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
3763- ggml_sycl_op_argsort (ctx, dst);
3764- }
3765-
3766-
37673638void ggml_sycl_set_main_device (const int main_device) try {
37683639 if (dpct::get_current_device_id () == static_cast <unsigned int > (main_device)) {
37693640 return ;
0 commit comments