File tree Expand file tree Collapse file tree 2 files changed +12
-6
lines changed
dpctl/tensor/libtensor/include/kernels/elementwise_functions Expand file tree Collapse file tree 2 files changed +12
-6
lines changed Original file line number Diff line number Diff line change 3131#include < sycl/sycl.hpp>
3232#include < type_traits>
3333
34+ #include " sycl_complex.hpp"
3435#include " vec_size_util.hpp"
3536
3637#include " kernels/dpctl_tensor_types.hpp"
@@ -56,11 +57,11 @@ using dpctl::tensor::type_utils::is_complex;
5657
5758template <typename argT, typename resT> struct ImagFunctor
5859{
59-
6060 // is function constant for given argT
61- using is_constant = typename std::false_type;
61+ using is_constant =
62+ typename std::is_same<is_complex<argT>, std::false_type>;
6263 // constant value, if constant
63- // constexpr resT constant_value = resT{};
64+ static constexpr resT constant_value = resT{0 };
6465 // is function defined for sycl::vec
6566 using supports_vec = typename std::false_type;
6667 // do both argTy and resTy support sugroup store/load operation
@@ -70,11 +71,13 @@ template <typename argT, typename resT> struct ImagFunctor
7071 resT operator ()(const argT &in) const
7172 {
7273 if constexpr (is_complex<argT>::value) {
73- return std::imag (in);
74+ using realT = typename argT::value_type;
75+ using sycl_complexT = typename exprm_ns::complex <realT>;
76+ return exprm_ns::imag (sycl_complexT (in));
7477 }
7578 else {
7679 static_assert (std::is_same_v<resT, argT>);
77- return resT{ 0 } ;
80+ return constant_value ;
7881 }
7982 }
8083};
Original file line number Diff line number Diff line change 3131#include < sycl/sycl.hpp>
3232#include < type_traits>
3333
34+ #include " sycl_complex.hpp"
3435#include " vec_size_util.hpp"
3536
3637#include " kernels/dpctl_tensor_types.hpp"
@@ -70,7 +71,9 @@ template <typename argT, typename resT> struct RealFunctor
7071 resT operator ()(const argT &in) const
7172 {
7273 if constexpr (is_complex<argT>::value) {
73- return std::real (in);
74+ using realT = typename argT::value_type;
75+ using sycl_complexT = typename exprm_ns::complex <realT>;
76+ return exprm_ns::real (sycl_complexT (in));
7477 }
7578 else {
7679 static_assert (std::is_same_v<resT, argT>);
You can’t perform that action at this time.
0 commit comments