@@ -84,6 +84,15 @@ static void gelu_quick(const T *x, T *dst, int k,
     dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
+template <typename T>
+static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        auto x_i = x[i];
+        dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
+    }
+}
+
 template <typename T>
 static void tanh(const T *x, T *dst, int k,
                  const sycl::nd_item<3> &item_ct1) {
@@ -400,6 +409,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
         });
 }
 
+
+template <typename T>
+static void gelu_erf_sycl(const T *x, T *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_erf(x, dst, k, item_ct1);
+        });
+}
+
 template <typename T>
 static void tanh_sycl(const T *x, T *dst, const int k,
                       queue_ptr stream) {
@@ -816,6 +839,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
+inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+#if defined(GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined(GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+
 inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined(GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1425,6 +1480,11 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_gelu_quick(ctx, dst);
 }
 
+void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_gelu_erf(ctx, dst);
+}
+
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_tanh(ctx, dst);
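Note: the added kernel implements the exact, erf-based GELU, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), as opposed to the tanh and sigmoid (gelu_quick) approximations already in this file. As a minimal sketch (not part of this patch; gelu_erf_ref is an illustrative name), a host-side reference mirroring the kernel's formula can be used to sanity-check its output:

    // Reference erf-based GELU on the host, same formula as the gelu_erf kernel.
    #include <cmath>
    #include <cstddef>

    static void gelu_erf_ref(const float * x, float * dst, std::size_t k) {
        const float SQRT_2_INV = 0.70710678118654752440084436210484f;
        for (std::size_t i = 0; i < k; ++i) {
            // GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
            dst[i] = 0.5f * x[i] * (1.0f + std::erf(x[i] * SQRT_2_INV));
        }
    }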