 #include "common.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
+#include <sycl/detail/builtins/builtins.hpp>
+#include <sycl/nd_item.hpp>
+#include <sycl/nd_range.hpp>
+#include <sycl/range.hpp>
 
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     const int ne10, const int ne11, const int ne12,
@@ -21,6 +25,33 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
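+// Elementwise sign: dst[i] = 1 if x[i] > 0, -1 if x[i] < 0, 0 otherwise.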
+template <typename T>
+static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_global_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
+}
+
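+// Elementwise absolute value via sycl::fabs.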
+template <typename T>
+static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_global_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::fabs(x[i]);
+}
+
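+// Elementwise ELU (alpha = 1): identity for x > 0, expm1(x) = exp(x) - 1 otherwise.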
+template <typename T>
+static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_global_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
+}
+
 template <typename T>
 static void gelu(const T * x, T * dst, const int k,
                  const sycl::nd_item<3> &item_ct1) {
@@ -335,6 +366,37 @@ static void silu_sycl(const T *x, T *dst, const int k,
         });
 }
 
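+// Launch helpers: round k up to whole 256-thread work-groups and run the kernel
+// over a 3D nd_range with all of the work in the innermost dimension.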
+template <typename T>
+static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard-coded work-group size of 256 for now
+    const int num_blocks = (k + 256 - 1) / 256;
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            sgn(x, dst, k, item_ct1);
+        });
+}
+
+template <typename T>
+static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard-coded work-group size of 256 for now
+    const int num_blocks = (k + 256 - 1) / 256;
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            abs_op(x, dst, k, item_ct1);
+        });
+}
+
+template <typename T>
+static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard-coded work-group size of 256 for now
+    const int num_blocks = (k + 256 - 1) / 256;
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
+            elu_op(x, dst, k, item_ct1);
+        });
+}
+
 template <typename T>
 static void gelu_quick_sycl(const T *x, T *dst, const int k,
                             queue_ptr stream) {
@@ -574,6 +636,106 @@ static void clamp_sycl(const T *x, T *dst, const float min,
         });
 }
 
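+// Op dispatchers: require matching src/dst types (F32 always; F16 only when
+// built with GGML_SYCL_F16), then launch the typed kernel on the context stream.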
+inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
 inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1388,3 +1550,20 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
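+// Public entry points: log the destination tensor type, then forward to the typed op.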
+void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_sgn(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_abs(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_elu(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}