@@ -24,8 +24,8 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
2424template <typename T>
2525static void gelu (const T * x, T * dst, const int k,
2626 const sycl::nd_item<3 > &item_ct1) {
27- const T GELU_COEF_A = to_T <T>(0 .044715f );
28- const T SQRT_2_OVER_PI = to_T <T>(0 .79788456080286535587989211986876f );
27+ const T GELU_COEF_A = static_cast <T>(0 .044715f );
28+ const T SQRT_2_OVER_PI = static_cast <T>(0 .79788456080286535587989211986876f );
2929 const int i = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
3030 item_ct1.get_local_id (2 );
3131
@@ -34,9 +34,9 @@ static void gelu(const T * x, T * dst, const int k,
3434 }
3535
3636 float xi = x[i];
37- dst[i] = to_T <T>(0 .5f ) * xi *
38- (to_T <T>(1 .0f ) +
39- sycl::tanh (SQRT_2_OVER_PI * xi * (to_T <T>(1 .0f ) + GELU_COEF_A * xi * xi)));
37+ dst[i] = static_cast <T>(0 .5f ) * xi *
38+ (static_cast <T>(1 .0f ) +
39+ sycl::tanh (SQRT_2_OVER_PI * xi * (static_cast <T>(1 .0f ) + GELU_COEF_A * xi * xi)));
4040}
4141
4242template <typename T>
@@ -48,7 +48,7 @@ static void silu(const T * x, T * dst, const int k,
4848 if (i >= k) {
4949 return ;
5050 }
51- dst[i] = x[i] / (to_T <T>(1 .0f ) + sycl::native::exp (-x[i]));
51+ dst[i] = x[i] / (static_cast <T>(1 .0f ) + sycl::native::exp (-x[i]));
5252}
5353
5454template <typename T>
@@ -60,7 +60,7 @@ static void gelu_quick(const T *x, T *dst, int k,
6060 if (i >= k) {
6161 return ;
6262 }
63- dst[i] = x[i] * (to_T <T>(1 .0f ) / (to_T <T>(1 .0f ) + sycl::native::exp (GELU_QUICK_COEF * x[i])));
63+ dst[i] = x[i] * (static_cast <T>(1 .0f ) / (static_cast <T>(1 .0f ) + sycl::native::exp (GELU_QUICK_COEF * x[i])));
6464}
6565
6666template <typename T>
@@ -95,7 +95,7 @@ static void sigmoid(const T * x, T * dst, const int k,
9595 if (i >= k) {
9696 return ;
9797 }
98- dst[i] = 1 .0f / (to_T <T>(1 .0f ) + sycl::native::exp (-x[i]));
98+ dst[i] = 1 .0f / (static_cast <T>(1 .0f ) + sycl::native::exp (-x[i]));
9999}
100100
101101template <typename T>
@@ -143,7 +143,7 @@ static void hardsigmoid(const T * x, T * dst, const int k,
143143 if (i >= k) {
144144 return ;
145145 }
146- dst[i] = sycl::fmin (to_T <T>(1 .0f ), sycl::fmax (to_T <T>(0 .0f ), (x[i] + to_T <T>(3 .0f )) / to_T <T>(6 .0f )));
146+ dst[i] = sycl::fmin (static_cast <T>(1 .0f ), sycl::fmax (static_cast <T>(0 .0f ), (x[i] + static_cast <T>(3 .0f )) / static_cast <T>(6 .0f )));
147147}
148148
149149template <typename T>
@@ -155,7 +155,7 @@ static void hardswish(const T * x, T * dst, const int k,
155155 if (i >= k) {
156156 return ;
157157 }
158- dst[i] = x[i] * sycl::fmin (to_T <T>(1 .0f ), sycl::fmax (to_T <T>(0 .0f ), (x[i] + to_T <T>(3 .0f )) / to_T <T>(6 .0f )));
158+ dst[i] = x[i] * sycl::fmin (static_cast <T>(1 .0f ), sycl::fmax (static_cast <T>(0 .0f ), (x[i] + static_cast <T>(3 .0f )) / static_cast <T>(6 .0f )));
159159}
160160
161161template <typename T>
@@ -276,7 +276,7 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne
276276 item_ct1.get_group (0 ) * ne00 * ne01;
277277 dst[offset_dst] = x[offset_src];
278278 } else {
279- dst[offset_dst] = to_T <T>(0 .0f );
279+ dst[offset_dst] = static_cast <T>(0 .0f );
280280 }
281281}
282282
0 commit comments