Skip to content

Commit 9db3e6f

Browse files
committed
Add FLOOR unary op with SYCL support
Implement the FLOOR op in the CPU and SYCL backends
1 parent 0a2a384 commit 9db3e6f

File tree

8 files changed

+75
-1
lines changed

8 files changed

+75
-1
lines changed

ggml/include/ggml.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ extern "C" {
559559

560560
enum ggml_unary_op {
561561
GGML_UNARY_OP_ABS,
562+
GGML_UNARY_OP_FLOOR,
562563
GGML_UNARY_OP_SGN,
563564
GGML_UNARY_OP_NEG,
564565
GGML_UNARY_OP_STEP,
@@ -1028,6 +1029,15 @@ extern "C" {
10281029
struct ggml_context * ctx,
10291030
struct ggml_tensor * a);
10301031

1032+
GGML_API struct ggml_tensor * ggml_floor(
1033+
struct ggml_context * ctx,
1034+
struct ggml_tensor * a);
1035+
1036+
GGML_API struct ggml_tensor * ggml_floor_inplace(
1037+
struct ggml_context * ctx,
1038+
struct ggml_tensor * a);
1039+
1040+
10311041
GGML_API struct ggml_tensor * ggml_sgn(
10321042
struct ggml_context * ctx,
10331043
struct ggml_tensor * a);

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9336,6 +9336,10 @@ void ggml_compute_forward_unary(
93369336
{
93379337
ggml_compute_forward_abs(params, dst);
93389338
} break;
9339+
case GGML_UNARY_OP_FLOOR:
9340+
{
9341+
ggml_compute_forward_floor(params, dst);
9342+
} break;
93399343
case GGML_UNARY_OP_SGN:
93409344
{
93419345
ggml_compute_forward_sgn(params, dst);

ggml/src/ggml-cpu/unary-ops.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ static inline float op_abs(float x) {
44
return fabsf(x);
55
}
66

7+
static inline float op_floor(float x) {
8+
return floorf(x);
9+
}
10+
711
static inline float op_sgn(float x) {
812
return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
913
}
@@ -125,6 +129,10 @@ void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor *
125129
unary_op<op_abs>(params, dst);
126130
}
127131

132+
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
133+
unary_op<op_floor>(params, dst);
134+
}
135+
128136
void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
129137
unary_op<op_sgn>(params, dst);
130138
}

ggml/src/ggml-cpu/unary-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ extern "C" {
77
#endif
88

99
void ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst);
10+
void ggml_compute_forward_floor(const struct ggml_compute_params * params,struct ggml_tensor * dst);
1011
void ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1112
void ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst);
1213
void ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ static __dpct_inline__ T op_abs(T x) {
3939
return sycl::fabs(x);
4040
}
4141

42+
template<typename T>
43+
static __dpct_inline__ T op_floor(T x) {
44+
return sycl::floor(x);
45+
}
46+
4247
template<typename T>
4348
static __dpct_inline__ T op_elu(T x) {
4449
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
@@ -164,6 +169,13 @@ static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::n
164169
}
165170
}
166171

172+
template<typename T>
173+
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
174+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
175+
dst[i] = op_floor(x[i]);
176+
}
177+
}
178+
167179
template<typename T>
168180
static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
169181
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -661,6 +673,19 @@ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor
661673
});
662674
}
663675

676+
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
677+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
678+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
679+
const int num_blocks = ceil_div(k_elements, 256);
680+
sycl_parallel_for(stream,
681+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
682+
sycl::range<1>(256)),
683+
[=](sycl::nd_item<1> item_ct1) {
684+
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
685+
});
686+
});
687+
}
688+
664689
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
665690
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
666691
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
@@ -1129,6 +1154,11 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11291154
ggml_sycl_op_clamp(ctx, dst);
11301155
}
11311156

1157+
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1158+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1159+
ggml_sycl_op_floor(ctx, dst);
1160+
}
1161+
11321162
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11331163
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
11341164
ggml_sycl_op_sgn(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7575

7676
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7777

78+
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
79+
7880
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7981

8082
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3626,6 +3626,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36263626
case GGML_UNARY_OP_ABS:
36273627
ggml_sycl_abs(ctx, dst);
36283628
break;
3629+
case GGML_UNARY_OP_FLOOR:
3630+
ggml_sycl_floor(ctx, dst);
3631+
break;
36293632
case GGML_UNARY_OP_ELU:
36303633
ggml_sycl_elu(ctx, dst);
36313634
break;
@@ -4182,6 +4185,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
41824185
case GGML_UNARY_OP_SGN:
41834186
case GGML_UNARY_OP_ABS:
41844187
case GGML_UNARY_OP_ELU:
4188+
case GGML_UNARY_OP_FLOOR:
41854189
#if defined (GGML_SYCL_F16)
41864190
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
41874191
#else

ggml/src/ggml.c

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,7 @@ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
11271127

11281128
static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11291129
"ABS",
1130+
"FLOOR",
11301131
"SGN",
11311132
"NEG",
11321133
"STEP",
@@ -1143,7 +1144,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11431144
"GELU_ERF",
11441145
};
11451146

1146-
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1147+
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
11471148

11481149

11491150
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
@@ -2479,6 +2480,20 @@ struct ggml_tensor * ggml_abs_inplace(
24792480
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
24802481
}
24812482

2483+
// ggml_floor
2484+
2485+
struct ggml_tensor * ggml_floor(
2486+
struct ggml_context * ctx,
2487+
struct ggml_tensor * a) {
2488+
return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
2489+
}
2490+
2491+
struct ggml_tensor * ggml_floor_inplace(
2492+
struct ggml_context * ctx,
2493+
struct ggml_tensor * a) {
2494+
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);
2495+
}
2496+
24822497
// ggml_sgn
24832498

24842499
struct ggml_tensor * ggml_sgn(

0 commit comments

Comments
 (0)