@@ -84,6 +84,14 @@ static void gelu_quick(const T *x, T *dst, int k,
     dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
+template <typename T>
+static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = static_cast<T>(0.5f)*x[i]*(static_cast<T>(1.0f) + sycl::erf(x[i]*SQRT_2_INV));
+    }
+}
+
 template <typename T>
 static void tanh(const T *x, T *dst, int k,
                  const sycl::nd_item<3> &item_ct1) {
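The added kernel computes the exact (erf-based) GELU, 0.5 * x * (1 + erf(x / sqrt(2))), using a grid-stride loop so the element count k may exceed the launched global range. For sanity-checking the kernel output, here is a minimal host-side reference sketch; the helper name gelu_erf_ref is hypothetical and not part of this PR:

#include <cmath>

// Exact GELU on the host: 0.5 * x * (1 + erf(x / sqrt(2))).
// Mirrors the device kernel above, with std::erf in place of sycl::erf.
static float gelu_erf_ref(float x) {
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678118654752440f));
}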
@@ -400,6 +408,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
         });
 }
 
+
+template <typename T>
+static void gelu_erf_sycl(const T *x, T *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_erf(x, dst, k, item_ct1);
+        });
+}
+
 template <typename T>
 static void tanh_sycl(const T *x, T *dst, const int k,
                       queue_ptr stream) {
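The launcher rounds k up to a whole number of SYCL_GELU_BLOCK_SIZE work-groups; excess work-items in the last group are harmless because the kernel's i < k guard skips them. ceil_div is assumed to be ggml's usual round-up integer division, roughly:

// Round-up integer division, a sketch of what ceil_div is assumed to do here.
static int ceil_div_sketch(int a, int b) {
    return (a + b - 1) / b;   // e.g. ceil_div_sketch(1000, 256) == 4
}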
@@ -816,6 +838,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
+inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined(GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined(GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+
 inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined(GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
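The dispatcher asserts that source and destination share a type, then selects the matching template instantiation; the F16 branch only compiles when GGML_SYCL_F16 is defined. cast_data<T> is assumed to return typed src/dst pointer views of the tensor buffers, on the order of this illustrative sketch:

// Assumed shape of the helper the switch relies on (illustrative only).
template <typename T>
struct typed_data_sketch {
    const T * src;   // typed view of dst->src[0]->data
    T *       dst;   // typed view of dst->data
};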
@@ -1432,6 +1486,12 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_gelu_erf(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
     ggml_sycl_op_tanh(ctx, dst);