@@ -21,6 +21,27 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
+template <typename T>
+static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
+    }
+}
+
+template <typename T>
+static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = sycl::fabs(x[i]);
+    }
+}
+
+template <typename T>
+static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
+    }
+}
+
 template <typename T>
 static void gelu(const T * x, T * dst, const int k,
                  const sycl::nd_item<3> &item_ct1) {
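The three new kernels use the same grid-stride loop as the other element-wise ops in this file: each work-item starts at its global id and advances by the global range, so any `k` is covered regardless of launch size. As a sanity check, here is a minimal scalar reference for the same math, with `std::fabs`/`std::expm1` standing in for the `sycl::` calls; this is a sketch for verifying outputs, not part of the patch:

```cpp
#include <cmath>
#include <cstdio>

// Scalar reference mirroring the device code:
// sgn -> -1, 0, or 1 depending on the sign of x
// abs -> |x|
// elu -> x for x > 0, expm1(x) = e^x - 1 otherwise
static float ref_sgn(float x) { return x > 0.f ? 1.f : (x < 0.f ? -1.f : 0.f); }
static float ref_abs(float x) { return std::fabs(x); }
static float ref_elu(float x) { return x > 0.f ? x : std::expm1(x); }

int main() {
    const float xs[] = { -2.f, -0.5f, 0.f, 0.5f, 2.f };
    for (float x : xs) {
        std::printf("x=%5.2f  sgn=%5.2f  abs=%5.2f  elu=%5.2f\n",
                    x, ref_sgn(x), ref_abs(x), ref_elu(x));
    }
    return 0;
}
```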
@@ -335,6 +356,37 @@ static void silu_sycl(const T *x, T *dst, const int k,
     });
 }
 
+template <typename T>
+static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sgn(x, dst, k, item_ct1);
+        });
+}
+
+template <typename T>
+static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item_ct1) {
+            abs_op(x, dst, k, item_ct1);
+        });
+}
+
+template <typename T>
+static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
+    // hard code for now
+    const int num_blocks = ceil_div(k, 256);
+    stream->parallel_for(
+        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
+        [=](sycl::nd_item<3> item_ct1) {
+            elu_op(x, dst, k, item_ct1);
+        });
+}
+
 template <typename T>
 static void gelu_quick_sycl(const T *x, T *dst, const int k,
                             queue_ptr stream) {
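Each launcher rounds the element count up to whole work-groups of 256 (the size hard-coded "for now"), so the global range is `num_blocks * 256` work-items in dimension 2; the grid-stride loop in the kernels makes the overhang in the last group a no-op. A small sketch of that launch arithmetic, assuming `ceil_div` is the usual `(a + b - 1) / b` rounding helper:

```cpp
#include <cstdio>

// Assumed to match the ceil_div helper used by the launchers:
// how many groups of size b are needed to cover a elements.
static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int block = 256;  // hard-coded work-group size in the launchers
    for (int k : { 1, 256, 257, 4096, 5000 }) {
        const int num_blocks = ceil_div(k, block);
        // num_blocks * block is the global range passed to parallel_for;
        // work-items past k simply fail the loop condition and do nothing.
        std::printf("k=%5d  blocks=%3d  global=%6d\n", k, num_blocks, num_blocks * block);
    }
    return 0;
}
```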
@@ -574,6 +626,106 @@ static void clamp_sycl(const T *x, T *dst, const float min,
     });
 }
 
+inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined(GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined(GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined(GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined(GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
+inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined(GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined(GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+            break;
+    }
+}
+
 inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined(GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
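All three wrappers share one shape: assert that src and dst are F32 (plus F16 when GGML_SYCL_F16 is defined) and have matching types, grab the context's queue, then switch on `dst->type` and hand typed pointers from `cast_data` to the launcher. `cast_data` is defined elsewhere in the backend; the stand-in below only illustrates the minimal contract the wrappers rely on (a struct carrying typed `src`/`dst` pointers) and is an assumption, not the backend's actual definition:

```cpp
#include "ggml.h"  // for ggml_tensor; assumes an in-tree build

// Hypothetical stand-in for the cast_data<T> helper the wrappers call:
// expose the unary op's input and output buffers as typed pointers.
template <typename T>
struct typed_data {
    const T * src;  // data of dst->src[0]
    T       * dst;  // data of dst
};

template <typename T>
static typed_data<T> cast_data(ggml_tensor * dst) {
    return {
        static_cast<const T *>(dst->src[0]->data),
        static_cast<T *>(dst->data),
    };
}
```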
@@ -1388,3 +1540,20 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_sgn(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_abs(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_elu(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
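The public entry points only log and forward to their `ggml_sycl_op_*` wrapper; hooking them into the backend's op dispatch happens elsewhere. A hedged sketch of what that unary-op switch could look like (`ggml_get_unary_op` and the `GGML_UNARY_OP_*` values are from ggml.h; the actual dispatch in ggml-sycl.cpp may differ):

```cpp
// Hypothetical dispatch fragment, for illustration only.
static bool ggml_sycl_unary_dispatch(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    switch (ggml_get_unary_op(dst)) {
        case GGML_UNARY_OP_SGN: ggml_sycl_sgn(ctx, dst); return true;
        case GGML_UNARY_OP_ABS: ggml_sycl_abs(ctx, dst); return true;
        case GGML_UNARY_OP_ELU: ggml_sycl_elu(ctx, dst); return true;
        default:                return false;  // not handled here
    }
}
```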