Skip to content

Commit b088a52

Browse files
Use std::cos for complex types as well.
1 parent 863ba3b commit b088a52

File tree

1 file changed

+1
-19
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+1
-19
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,7 @@ template <typename argT, typename resT> struct CosFunctor
4141

4242
resT operator()(const argT &in)
4343
{
44-
if constexpr (is_complex<argT>::value) {
45-
using realT = typename argT::value_type;
46-
// cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y)
47-
auto v = std::real(in);
48-
realT cosX_val;
49-
const realT sinX_val = sycl::sincos(-v, &cosX_val);
50-
v = std::imag(in);
51-
const realT sinhY_val = sycl::sinh(v);
52-
const realT coshY_val = sycl::cosh(v);
53-
54-
const realT res_re = coshY_val * cosX_val;
55-
const realT res_im = sinX_val * sinhY_val;
56-
return resT{res_re, res_im};
57-
}
58-
else {
59-
static_assert(std::is_floating_point_v<argT> ||
60-
std::is_same_v<argT, sycl::half>);
61-
return std::cos(in);
62-
}
44+
return std::cos(in);
6345
}
6446
};
6547

0 commit comments

Comments
 (0)