Skip to content

Commit c9f419f

Browse files
committed
ARGMAX: move to a separate file
1 parent 1d5ad8d commit c9f419f

File tree

4 files changed

+82
-68
lines changed

4 files changed

+82
-68
lines changed

ggml/src/ggml-sycl/argmax.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include "argmax.hpp"
2+
3+
static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, queue_ptr stream) {
4+
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
5+
const sycl::range<3> block_nums(1, nrows, 1);
6+
const size_t shared_mem = 256 * sizeof(float);
7+
8+
stream->submit([&](sycl::handler & cgh) {
9+
sycl::local_accessor<float, 1> shared_data(sycl::range<1>(shared_mem / sizeof(float)), cgh);
10+
sycl::local_accessor<int, 1> shared_indices(sycl::range<1>(shared_mem / sizeof(float)), cgh);
11+
12+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
13+
const int tid = item_ct1.get_local_id(2);
14+
const int row = item_ct1.get_global_id(1);
15+
16+
float max_val = -INFINITY;
17+
int max_idx = -1;
18+
19+
for (int col = tid; col < ncols; col += 256) {
20+
float val = x[row * ncols + col];
21+
if (val > max_val) {
22+
max_val = val;
23+
max_idx = col;
24+
}
25+
}
26+
27+
shared_data[tid] = max_val;
28+
shared_indices[tid] = max_idx;
29+
item_ct1.barrier(sycl::access::fence_space::local_space);
30+
31+
for (int stride = 256 / 2; stride > 0; stride >>= 1) {
32+
if (tid < stride) {
33+
float val1 = shared_data[tid];
34+
float val2 = shared_data[tid + stride];
35+
if (val2 > val1) {
36+
shared_data[tid] = val2;
37+
shared_indices[tid] = shared_indices[tid + stride];
38+
}
39+
}
40+
item_ct1.barrier(sycl::access::fence_space::local_space);
41+
}
42+
43+
if (tid == 0) {
44+
dst[row] = shared_indices[0];
45+
}
46+
});
47+
});
48+
}
49+
50+
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
51+
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
52+
53+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
54+
GGML_ASSERT(dst->type == GGML_TYPE_I32);
55+
56+
const int64_t ncols = dst->src[0]->ne[0];
57+
const int64_t nrows = ggml_nrows(dst->src[0]);
58+
59+
dpct::queue_ptr main_stream = ctx.stream();
60+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
61+
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
62+
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
63+
} catch (const sycl::exception & exc) {
64+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
65+
std::exit(1);
66+
}
67+
68+
void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
69+
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
70+
GGML_SYCL_DEBUG("call %s\n", __func__);
71+
ggml_sycl_op_argmax(ctx, dst);
72+
GGML_SYCL_DEBUG("call %s done\n", __func__);
73+
}

ggml/src/ggml-sycl/argmax.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_ARGMAX_HPP
2+
#define GGML_SYCL_ARGMAX_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_ARGMAX_HPP

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "outprod.hpp"
3131
#include "element_wise.hpp"
3232
#include "binbcast.hpp"
33+
#include "argmax.hpp"
3334
#include "gla.hpp"
3435

3536
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,58 +2347,6 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
23472347
}
23482348
}
23492349

2350-
static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
2351-
const int nrows, queue_ptr stream) {
2352-
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
2353-
const sycl::range<3> block_nums(1, nrows, 1);
2354-
const size_t shared_mem = 256 * sizeof(float);
2355-
2356-
stream->submit([&](sycl::handler &cgh) {
2357-
sycl::local_accessor<float, 1> shared_data(
2358-
sycl::range<1>(shared_mem/sizeof(float)), cgh);
2359-
sycl::local_accessor<int, 1> shared_indices(
2360-
sycl::range<1>(shared_mem/sizeof(float)), cgh);
2361-
2362-
cgh.parallel_for(
2363-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
2364-
[=](sycl::nd_item<3> item_ct1) {
2365-
const int tid = item_ct1.get_local_id(2);
2366-
const int row = item_ct1.get_global_id(1);
2367-
2368-
float max_val = -INFINITY;
2369-
int max_idx = -1;
2370-
2371-
for (int col = tid; col < ncols; col += 256) {
2372-
float val = x[row * ncols + col];
2373-
if (val > max_val) {
2374-
max_val = val;
2375-
max_idx = col;
2376-
}
2377-
}
2378-
2379-
shared_data[tid] = max_val;
2380-
shared_indices[tid] = max_idx;
2381-
item_ct1.barrier(sycl::access::fence_space::local_space);
2382-
2383-
for (int stride = 256/2; stride > 0; stride >>= 1) {
2384-
if (tid < stride) {
2385-
float val1 = shared_data[tid];
2386-
float val2 = shared_data[tid + stride];
2387-
if (val2 > val1) {
2388-
shared_data[tid] = val2;
2389-
shared_indices[tid] = shared_indices[tid + stride];
2390-
}
2391-
}
2392-
item_ct1.barrier(sycl::access::fence_space::local_space);
2393-
}
2394-
2395-
2396-
if (tid == 0) {
2397-
dst[row] = shared_indices[0];
2398-
}
2399-
});
2400-
});
2401-
}
24022350
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
24032351
const int ncols_x, const int nrows_x,
24042352
const int rows_per_channel, const int n_past,
@@ -2746,22 +2694,6 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
27462694
argsort_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, order, main_stream);
27472695
}
27482696

2749-
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
2750-
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
2751-
2752-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
2753-
GGML_ASSERT(dst->type == GGML_TYPE_I32);
2754-
2755-
const int64_t ncols = dst->src[0]->ne[0];
2756-
const int64_t nrows = ggml_nrows(dst->src[0]);
2757-
2758-
dpct::queue_ptr main_stream = ctx.stream();
2759-
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
2760-
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
2761-
2762-
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
2763-
}
2764-
27652697
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
27662698

27672699
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);

0 commit comments

Comments
 (0)