Skip to content

Commit cbd2029

Browse files
committed
Add complex type handling
1 parent 2e4612b commit cbd2029

File tree

2 files changed

+158
-12
lines changed
  • dpnp/backend
    • extensions/ufunc/elementwise_functions
    • kernels/elementwise_functions

2 files changed

+158
-12
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/sinc.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,13 @@ namespace td_ns = dpctl::tensor::type_dispatch;
5959
template <typename T>
6060
struct OutputType
6161
{
62-
using value_type =
63-
typename std::disjunction<td_ns::TypeMapResultEntry<T, sycl::half>,
64-
td_ns::TypeMapResultEntry<T, float>,
65-
td_ns::TypeMapResultEntry<T, double>,
66-
td_ns::DefaultResultEntry<void>>::result_type;
62+
using value_type = typename std::disjunction<
63+
td_ns::TypeMapResultEntry<T, sycl::half>,
64+
td_ns::TypeMapResultEntry<T, float>,
65+
td_ns::TypeMapResultEntry<T, double>,
66+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
67+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
68+
td_ns::DefaultResultEntry<void>>::result_type;
6769
};
6870

6971
using dpnp::kernels::sinc::SincFunctor;

dpnp/backend/kernels/elementwise_functions/sinc.hpp

Lines changed: 151 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,176 @@
2525

2626
#pragma once
2727

28+
#define SYCL_EXT_ONEAPI_COMPLEX
29+
#if __has_include(<sycl/ext/oneapi/experimental/sycl_complex.hpp>)
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#else
32+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
33+
#endif
34+
2835
#include <sycl/sycl.hpp>
2936

37+
// dpctl tensor headers
38+
#include "utils/type_utils.hpp"
39+
3040
namespace dpnp::kernels::sinc
3141
{
32-
template <typename argT, typename resT>
42+
namespace tu_ns = dpctl::tensor::type_utils;
43+
44+
namespace impl
45+
{
46+
namespace exprm_ns = sycl::ext::oneapi::experimental;
47+
48+
template <typename Tp>
49+
inline Tp sin(const Tp &in)
50+
{
51+
if constexpr (tu_ns::is_complex<Tp>::value) {
52+
using realTp = typename Tp::value_type;
53+
54+
constexpr realTp q_nan = std::numeric_limits<realTp>::quiet_NaN();
55+
56+
realTp const &in_re = std::real(in);
57+
realTp const &in_im = std::imag(in);
58+
59+
const bool in_re_finite = sycl::isfinite(in_re);
60+
const bool in_im_finite = sycl::isfinite(in_im);
61+
/*
62+
* Handle the nearly-non-exceptional cases where
63+
* real and imaginary parts of input are finite.
64+
*/
65+
if (in_re_finite && in_im_finite) {
66+
Tp res = exprm_ns::sin(exprm_ns::complex<realTp>(in)); // sin(in);
67+
if (in_re == realTp(0)) {
68+
res.real(sycl::copysign(realTp(0), in_re));
69+
}
70+
return res;
71+
}
72+
73+
/*
74+
* since sin(in) = -I * sinh(I * in), for special cases,
75+
* we calculate real and imaginary parts of z = sinh(I * in) and
76+
* then return { imag(z) , -real(z) } which is sin(in).
77+
*/
78+
const realTp x = -in_im;
79+
const realTp y = in_re;
80+
const bool xfinite = in_im_finite;
81+
const bool yfinite = in_re_finite;
82+
/*
83+
* sinh(+-0 +- I Inf) = sign(d(+-0, dNaN))0 + I dNaN.
84+
* The sign of 0 in the result is unspecified. Choice = normally
85+
* the same as dNaN.
86+
*
87+
* sinh(+-0 +- I NaN) = sign(d(+-0, NaN))0 + I d(NaN).
88+
* The sign of 0 in the result is unspecified. Choice = normally
89+
* the same as d(NaN).
90+
*/
91+
if (x == realTp(0) && !yfinite) {
92+
const realTp sinh_im = q_nan;
93+
const realTp sinh_re = sycl::copysign(realTp(0), x * sinh_im);
94+
return Tp{sinh_im, -sinh_re};
95+
}
96+
97+
/*
98+
* sinh(+-Inf +- I 0) = +-Inf + I +-0.
99+
*
100+
* sinh(NaN +- I 0) = d(NaN) + I +-0.
101+
*/
102+
if (y == realTp(0) && !xfinite) {
103+
if (std::isnan(x)) {
104+
const realTp sinh_re = x;
105+
const realTp sinh_im = y;
106+
return Tp{sinh_im, -sinh_re};
107+
}
108+
const realTp sinh_re = x;
109+
const realTp sinh_im = sycl::copysign(realTp(0), y);
110+
return Tp{sinh_im, -sinh_re};
111+
}
112+
113+
/*
114+
* sinh(x +- I Inf) = dNaN + I dNaN.
115+
*
116+
* sinh(x + I NaN) = d(NaN) + I d(NaN).
117+
*/
118+
if (xfinite && !yfinite) {
119+
const realTp sinh_re = q_nan;
120+
const realTp sinh_im = x * sinh_re;
121+
return Tp{sinh_im, -sinh_re};
122+
}
123+
124+
/*
125+
* sinh(+-Inf + I NaN) = +-Inf + I d(NaN).
126+
* The sign of Inf in the result is unspecified. Choice = normally
127+
* the same as d(NaN).
128+
*
129+
* sinh(+-Inf +- I Inf) = +Inf + I dNaN.
130+
* The sign of Inf in the result is unspecified.
131+
* Choice = always - here for sinh to have positive result for
132+
* imaginary part of sin.
133+
*
134+
* sinh(+-Inf + I y) = +-Inf cos(y) + I Inf sin(y)
135+
*/
136+
if (std::isinf(x)) {
137+
if (!yfinite) {
138+
const realTp sinh_re = -x * x;
139+
const realTp sinh_im = x * (y - y);
140+
return Tp{sinh_im, -sinh_re};
141+
}
142+
const realTp sinh_re = x * sycl::cos(y);
143+
const realTp sinh_im =
144+
std::numeric_limits<realTp>::infinity() * sycl::sin(y);
145+
return Tp{sinh_im, -sinh_re};
146+
}
147+
148+
/*
149+
* sinh(NaN + I NaN) = d(NaN) + I d(NaN).
150+
*
151+
* sinh(NaN +- I Inf) = d(NaN) + I d(NaN).
152+
*
153+
* sinh(NaN + I y) = d(NaN) + I d(NaN).
154+
*/
155+
const realTp y_m_y = (y - y);
156+
const realTp sinh_re = (x * x) * y_m_y;
157+
const realTp sinh_im = (x + x) * y_m_y;
158+
return Tp{sinh_im, -sinh_re};
159+
}
160+
else {
161+
if (in == Tp(0)) {
162+
return in;
163+
}
164+
return sycl::sin(in);
165+
}
166+
}
167+
} // namespace impl
168+
169+
template <typename argT, typename Tp>
33170
struct SincFunctor
34171
{
35172
// is function constant for given argT
36173
using is_constant = typename std::false_type;
37174
// constant value, if constant
38-
// constexpr resT constant_value = resT{};
175+
// constexpr Tp constant_value = Tp{};
39176
// is function defined for sycl::vec
40177
using supports_vec = typename std::false_type;
41-
// do both argT and resT support subgroup store/load operation
42-
using supports_sg_loadstore = typename std::true_type;
178+
// do both argT and Tp support subgroup store/load operation
179+
using supports_sg_loadstore = typename std::negation<
180+
std::disjunction<tu_ns::is_complex<Tp>, tu_ns::is_complex<argT>>>;
43181

44-
resT operator()(const argT &x) const
182+
Tp operator()(const argT &x) const
45183
{
46184
constexpr argT pi =
47185
static_cast<argT>(3.1415926535897932384626433832795029L);
48186
const argT y = pi * x;
49187

50188
if (y == argT(0)) {
51-
return argT(1);
189+
return Tp(1);
190+
}
191+
192+
if constexpr (tu_ns::is_complex<argT>::value) {
193+
return impl::sin(y) / y;
194+
}
195+
else {
196+
return sycl::sinpi(x) / y;
52197
}
53-
return sycl::sinpi(x) / y;
54198
}
55199
};
56200
} // namespace dpnp::kernels::sinc

0 commit comments

Comments
 (0)