diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 89adabff41..0fa432546e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -53,14 +54,15 @@ using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::is_complex_v; template struct ImagFunctor { - // is function constant for given argT - using is_constant = typename std::false_type; + using is_constant = + typename std::is_same, std::false_type>; // constant value, if constant - // constexpr resT constant_value = resT{}; + static constexpr resT constant_value = resT{0}; // is function defined for sycl::vec using supports_vec = typename std::false_type; // do both argTy and resTy support sugroup store/load operation @@ -69,12 +71,14 @@ template struct ImagFunctor resT operator()(const argT &in) const { - if constexpr (is_complex::value) { - return std::imag(in); + if constexpr (is_complex_v) { + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::imag(sycl_complexT(in)); } else { static_assert(std::is_same_v); - return resT{0}; + return constant_value; } } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index bb22352907..04ed3a6e49 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -53,6 +54,7 @@ using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::is_complex_v; template struct RealFunctor { @@ -69,8 +71,10 @@ template struct RealFunctor resT operator()(const argT &in) const { - if constexpr (is_complex::value) { - return std::real(in); + if constexpr (is_complex_v) { + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::real(sycl_complexT(in)); } else { static_assert(std::is_same_v);