Skip to content

Commit 6df9afd

Browse files
committed
ops: add ROUND operator support for CPU and SYCL
1 parent 0a2a384 commit 6df9afd

File tree

12 files changed

+122
-1
lines changed

12 files changed

+122
-1
lines changed

docs/ops.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ Legend:
7676
| ROLL ||||||||||
7777
| ROPE || 🟡 ||||||||
7878
| ROPE_BACK ||||||||||
79+
| ROUND ||||||||||
7980
| RWKV_WKV6 ||||||||||
8081
| RWKV_WKV7 ||||||||||
8182
| SCALE || 🟡 ||||||||

docs/ops/CPU.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
22
"CPU","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
33
"CPU","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
4+
"CPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
5+
"CPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
46
"CPU","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
57
"CPU","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
68
"CPU","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
@@ -61,6 +63,8 @@
6163
"CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
6264
"CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
6365
"CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
66+
"CPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
67+
"CPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
6468
"CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
6569
"CPU","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
6670
"CPU","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"

docs/ops/SYCL.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
22
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
33
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
4+
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
5+
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
46
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
57
"SYCL0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
68
"SYCL0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
@@ -61,6 +63,8 @@
6163
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
6264
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
6365
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
66+
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
67+
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
6468
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
6569
"SYCL0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
6670
"SYCL0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"

ggml/include/ggml.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ extern "C" {
559559

560560
enum ggml_unary_op {
561561
GGML_UNARY_OP_ABS,
562+
GGML_UNARY_OP_ROUND,
562563
GGML_UNARY_OP_SGN,
563564
GGML_UNARY_OP_NEG,
564565
GGML_UNARY_OP_STEP,
@@ -1027,6 +1028,13 @@ extern "C" {
10271028
GGML_API struct ggml_tensor * ggml_abs_inplace(
10281029
struct ggml_context * ctx,
10291030
struct ggml_tensor * a);
1031+
// Public declaration for the ROUND unary operator: takes a context `ctx` and an
// input tensor `a`, and returns a tensor pointer.
// NOTE(review): the definition is not visible here; presumably the result is a
// new tensor tagged with GGML_UNARY_OP_ROUND — confirm against the implementation.
GGML_API struct ggml_tensor * ggml_round(
1032+
struct ggml_context * ctx,
1033+
struct ggml_tensor * a);
1034+
1035+
// Declaration of the `_inplace` variant of ggml_round; same parameters and
// return type as ggml_round.
// NOTE(review): presumably the result reuses `a`'s storage, as the name
// suggests — the definition is not visible here, confirm before relying on it.
GGML_API struct ggml_tensor * ggml_round_inplace(
1036+
struct ggml_context * ctx,
1037+
struct ggml_tensor * a);
10301038

10311039
GGML_API struct ggml_tensor * ggml_sgn(
10321040
struct ggml_context * ctx,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9336,6 +9336,10 @@ void ggml_compute_forward_unary(
93369336
{
93379337
ggml_compute_forward_abs(params, dst);
93389338
} break;
9339+
case GGML_UNARY_OP_ROUND:
9340+
{
9341+
ggml_compute_forward_round(params, dst);
9342+
} break;
93399343
case GGML_UNARY_OP_SGN:
93409344
{
93419345
ggml_compute_forward_sgn(params, dst);

ggml/src/ggml-cpu/unary-ops.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ static inline float op_abs(float x) {
44
return fabsf(x);
55
}
66

7+
// Scalar kernel for the ROUND unary op: returns roundf(x), i.e. x rounded to
// the nearest integral value with halfway cases rounded away from zero
// (C standard library semantics of roundf).
static inline float op_round(float x) {
8+
return roundf(x);
9+
}
10+
711
// Scalar sign function: 1.f when x > 0, -1.f when x < 0, and 0.f otherwise.
// The final branch also catches NaN, because both comparisons are false for it.
static inline float op_sgn(float x) {
812
return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
913
}
@@ -125,6 +129,10 @@ void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor *
125129
unary_op<op_abs>(params, dst);
126130
}
127131

132+
// CPU forward pass for the ROUND unary op: forwards `params` and `dst` to the
// shared unary_op template, instantiated with the scalar op_round kernel.
void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
133+
unary_op<op_round>(params, dst);
134+
}
135+
128136
// CPU forward pass for the SGN unary op: forwards `params` and `dst` to the
// shared unary_op template, instantiated with the scalar op_sgn kernel.
void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
129137
unary_op<op_sgn>(params, dst);
130138
}

ggml/src/ggml-cpu/unary-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ extern "C" {
77
#endif
88

99
void ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst);
10+
// Forward computation for GGML_UNARY_OP_ROUND; called from the unary-op dispatch in ops.cpp.
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1011
void ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1112
void ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1213
void ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ static __dpct_inline__ T op_abs(T x) {
3939
return sycl::fabs(x);
4040
}
4141

42+
// Element-wise rounding helper for the SYCL kernels: returns sycl::round(x)
// for whatever element type T it is instantiated with. Per the SYCL math
// built-ins, round() picks the nearest integral value and rounds halfway
// cases away from zero.
template<typename T>
43+
static __dpct_inline__ T op_round(T x) {
44+
return sycl::round(x);
45+
}
46+
4247
template<typename T>
4348
static __dpct_inline__ T op_elu(T x) {
4449
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
@@ -164,6 +169,13 @@ static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::n
164169
}
165170
}
166171

172+
// Device-side body of the ROUND kernel: for every index `i` produced by
// SYCL_GLOBAL_ID_LOOP(k, item_ct1) it stores op_round(x[i]) into dst[i].
//   x        - input elements (read-only)
//   dst      - output elements
//   k        - element count handed to the loop macro
//   item_ct1 - the work-item's 1-D nd_item
// NOTE(review): SYCL_GLOBAL_ID_LOOP is defined elsewhere; presumably it bounds
// `i` by `k` so surplus work-items do nothing — confirm against the macro.
template<typename T>
173+
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
174+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
175+
dst[i] = op_round(x[i]);
176+
}
177+
}
178+
167179
template<typename T>
168180
static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
169181
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -661,6 +673,19 @@ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor
661673
});
662674
}
663675

676+
// Host-side launcher for the ROUND op. It hands dispatch_ggml_sycl_op_unary a
// generic lambda that receives the source pointer, destination pointer,
// element count and stream, and submits unary_op_round_kernel over a 1-D
// nd_range with work-groups of 256 items.
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
677+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
678+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
679+
// Number of 256-wide work-groups needed for k_elements
// (ceil_div presumably rounds up so a partial last group is included — confirm).
const int num_blocks = ceil_div(k_elements, 256);
680+
sycl_parallel_for(stream,
681+
// Global size is num_blocks * 256; local (work-group) size is 256.
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
682+
sycl::range<1>(256)),
683+
[=](sycl::nd_item<1> item_ct1) {
684+
// The kernel is given k_elements so it can bound its own indexing.
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
685+
});
686+
});
687+
}
688+
664689
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
665690
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
666691
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
@@ -1139,6 +1164,11 @@ void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11391164
ggml_sycl_op_abs(ctx, dst);
11401165
}
11411166

1167+
// Backend entry point for the ROUND op on SYCL: creates a scope_op_debug_print
// object for this call (tagged with the function name and one source tensor),
// then delegates to ggml_sycl_op_round.
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1168+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1169+
ggml_sycl_op_round(ctx, dst);
1170+
}
1171+
11421172
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11431173
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
11441174
ggml_sycl_op_elu(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7575

7676
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7777

78+
// SYCL implementation of the ROUND unary op; invoked for GGML_UNARY_OP_ROUND by the backend's compute dispatch.
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
79+
7880
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7981

8082
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3626,6 +3626,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36263626
case GGML_UNARY_OP_ABS:
36273627
ggml_sycl_abs(ctx, dst);
36283628
break;
3629+
case GGML_UNARY_OP_ROUND:
3630+
ggml_sycl_round(ctx, dst);
3631+
break;
36293632
case GGML_UNARY_OP_ELU:
36303633
ggml_sycl_elu(ctx, dst);
36313634
break;
@@ -4181,6 +4184,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
41814184
case GGML_UNARY_OP_EXP:
41824185
case GGML_UNARY_OP_SGN:
41834186
case GGML_UNARY_OP_ABS:
4187+
case GGML_UNARY_OP_ROUND:
41844188
case GGML_UNARY_OP_ELU:
41854189
#if defined (GGML_SYCL_F16)
41864190
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);

0 commit comments

Comments
 (0)