File tree Expand file tree Collapse file tree 2 files changed +12
-6
lines changed
dpctl/tensor/libtensor/include/kernels/elementwise_functions Expand file tree Collapse file tree 2 files changed +12
-6
lines changed Original file line number Diff line number Diff line change 31
31
#include < sycl/sycl.hpp>
32
32
#include < type_traits>
33
33
34
+ #include " sycl_complex.hpp"
34
35
#include " vec_size_util.hpp"
35
36
36
37
#include " kernels/dpctl_tensor_types.hpp"
@@ -56,11 +57,11 @@ using dpctl::tensor::type_utils::is_complex;
56
57
57
58
template <typename argT, typename resT> struct ImagFunctor
58
59
{
59
-
60
60
// is function constant for given argT
61
- using is_constant = typename std::false_type;
61
+ using is_constant =
62
+ typename std::is_same<is_complex<argT>, std::false_type>;
62
63
// constant value, if constant
63
- // constexpr resT constant_value = resT{};
64
+ static constexpr resT constant_value = resT{0 };
64
65
// is function defined for sycl::vec
65
66
using supports_vec = typename std::false_type;
66
67
// do both argTy and resTy support sugroup store/load operation
@@ -70,11 +71,13 @@ template <typename argT, typename resT> struct ImagFunctor
70
71
resT operator ()(const argT &in) const
71
72
{
72
73
if constexpr (is_complex<argT>::value) {
73
- return std::imag (in);
74
+ using realT = typename argT::value_type;
75
+ using sycl_complexT = typename exprm_ns::complex<realT>;
76
+ return exprm_ns::imag (sycl_complexT (in));
74
77
}
75
78
else {
76
79
static_assert (std::is_same_v<resT, argT>);
77
- return resT{ 0 } ;
80
+ return constant_value ;
78
81
}
79
82
}
80
83
};
Original file line number Diff line number Diff line change 31
31
#include < sycl/sycl.hpp>
32
32
#include < type_traits>
33
33
34
+ #include " sycl_complex.hpp"
34
35
#include " vec_size_util.hpp"
35
36
36
37
#include " kernels/dpctl_tensor_types.hpp"
@@ -70,7 +71,9 @@ template <typename argT, typename resT> struct RealFunctor
70
71
resT operator ()(const argT &in) const
71
72
{
72
73
if constexpr (is_complex<argT>::value) {
73
- return std::real (in);
74
+ using realT = typename argT::value_type;
75
+ using sycl_complexT = typename exprm_ns::complex<realT>;
76
+ return exprm_ns::real (sycl_complexT (in));
74
77
}
75
78
else {
76
79
static_assert (std::is_same_v<resT, argT>);
You can’t perform that action at this time.
0 commit comments