@@ -39,6 +39,11 @@ static __dpct_inline__ T op_abs(T x) {
3939 return sycl::fabs (x);
4040}
4141
42+ template <typename T>
43+ static __dpct_inline__ T op_floor (T x) {
44+ return sycl::floor (x);
45+ }
46+
4247template <typename T>
4348static __dpct_inline__ T op_elu (T x) {
4449 return (x > static_cast <T>(0 .f )) ? x : sycl::expm1 (x);
@@ -164,6 +169,13 @@ static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::n
164169 }
165170}
166171
172+ template <typename T>
173+ static void unary_op_floor_kernel (const T * x, T * dst, const int k, const sycl::nd_item<1 > &item_ct1) {
174+ SYCL_GLOBAL_ID_LOOP (k, item_ct1) {
175+ dst[i] = op_floor (x[i]);
176+ }
177+ }
178+
167179template <typename T>
168180static void unary_op_elu_kernel (const T * x, T * dst, const int k, const sycl::nd_item<1 > &item_ct1) {
169181 SYCL_GLOBAL_ID_LOOP (k, item_ct1) {
@@ -661,6 +673,19 @@ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor
661673 });
662674}
663675
676+ static inline void ggml_sycl_op_floor (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
677+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary (ctx, dst,
678+ [](const auto * src, auto * dst_ptr, int k_elements, queue_ptr stream) {
679+ const int num_blocks = ceil_div (k_elements, 256 );
680+ sycl_parallel_for (stream,
681+ sycl::nd_range<1 >(sycl::range<1 >(num_blocks) * sycl::range<1 >(256 ),
682+ sycl::range<1 >(256 )),
683+ [=](sycl::nd_item<1 > item_ct1) {
684+ unary_op_floor_kernel (src, dst_ptr, k_elements, item_ct1);
685+ });
686+ });
687+ }
688+
664689static inline void ggml_sycl_op_elu (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
665690 ggml_sycl_detail::dispatch_ggml_sycl_op_unary (ctx, dst,
666691 [](const auto * src, auto * dst_ptr, int k_elements, queue_ptr stream) {
@@ -1129,6 +1154,11 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11291154 ggml_sycl_op_clamp (ctx, dst);
11301155}
11311156
1157+ void ggml_sycl_floor (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1158+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
1159+ ggml_sycl_op_floor (ctx, dst);
1160+ }
1161+
11321162void ggml_sycl_sgn (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11331163 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
11341164 ggml_sycl_op_sgn (ctx, dst);
0 commit comments