Skip to content

Commit 0a447ad

Browse files
committed
Update implementations of real and imag
`imag` uses static constant value of 0 for real inputs and both use sycl complex extension
1 parent 8adc781 commit 0a447ad

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

34+
#include "sycl_complex.hpp"
3435
#include "vec_size_util.hpp"
3536

3637
#include "kernels/dpctl_tensor_types.hpp"
@@ -56,11 +57,11 @@ using dpctl::tensor::type_utils::is_complex;
5657

5758
template <typename argT, typename resT> struct ImagFunctor
5859
{
59-
6060
// is function constant for given argT
61-
using is_constant = typename std::false_type;
61+
using is_constant =
62+
typename std::is_same<is_complex<argT>, std::false_type>;
6263
// constant value, if constant
63-
// constexpr resT constant_value = resT{};
64+
static constexpr resT constant_value = resT{0};
6465
// is function defined for sycl::vec
6566
using supports_vec = typename std::false_type;
6667
// do both argTy and resTy support sugroup store/load operation
@@ -70,11 +71,13 @@ template <typename argT, typename resT> struct ImagFunctor
7071
resT operator()(const argT &in) const
7172
{
7273
if constexpr (is_complex<argT>::value) {
73-
return std::imag(in);
74+
using realT = typename argT::value_type;
75+
using sycl_complexT = typename exprm_ns::complex<realT>;
76+
return exprm_ns::imag(sycl_complexT(in));
7477
}
7578
else {
7679
static_assert(std::is_same_v<resT, argT>);
77-
return resT{0};
80+
return constant_value;
7881
}
7982
}
8083
};

dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

34+
#include "sycl_complex.hpp"
3435
#include "vec_size_util.hpp"
3536

3637
#include "kernels/dpctl_tensor_types.hpp"
@@ -70,7 +71,9 @@ template <typename argT, typename resT> struct RealFunctor
7071
resT operator()(const argT &in) const
7172
{
7273
if constexpr (is_complex<argT>::value) {
73-
return std::real(in);
74+
using realT = typename argT::value_type;
75+
using sycl_complexT = typename exprm_ns::complex<realT>;
76+
return exprm_ns::real(sycl_complexT(in));
7477
}
7578
else {
7679
static_assert(std::is_same_v<resT, argT>);

0 commit comments

Comments
 (0)