Skip to content

Commit b6638eb

Browse files
author
Vahid Tavanashad
committed
make changes on python-side only
1 parent d4d09c6 commit b6638eb

File tree

5 files changed

+49
-24
lines changed

5 files changed

+49
-24
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/fix.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,7 @@ template <typename T>
6060
struct OutputType
6161
{
6262
using value_type =
63-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
64-
td_ns::TypeMapResultEntry<T, std::uint16_t>,
65-
td_ns::TypeMapResultEntry<T, std::uint32_t>,
66-
td_ns::TypeMapResultEntry<T, std::uint64_t>,
67-
td_ns::TypeMapResultEntry<T, std::int8_t>,
68-
td_ns::TypeMapResultEntry<T, std::int16_t>,
69-
td_ns::TypeMapResultEntry<T, std::int32_t>,
70-
td_ns::TypeMapResultEntry<T, std::int64_t>,
71-
td_ns::TypeMapResultEntry<T, sycl::half>,
63+
typename std::disjunction<td_ns::TypeMapResultEntry<T, sycl::half>,
7264
td_ns::TypeMapResultEntry<T, float>,
7365
td_ns::TypeMapResultEntry<T, double>,
7466
td_ns::DefaultResultEntry<void>>::result_type;

dpnp/backend/kernels/elementwise_functions/fix.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,7 @@ struct FixFunctor
4343

4444
resT operator()(const argT &x) const
4545
{
46-
if constexpr (std::is_integral_v<argT>) {
47-
return x;
48-
}
49-
else {
50-
return (x >= 0.0) ? sycl::floor(x) : sycl::ceil(x);
51-
}
46+
return (x >= 0.0) ? sycl::floor(x) : sycl::ceil(x);
5247
}
5348
};
5449
} // namespace dpnp::kernels::fix

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"DPNPI0",
4242
"DPNPAngle",
4343
"DPNPBinaryFunc",
44+
"DPNPFix",
4445
"DPNPImag",
4546
"DPNPReal",
4647
"DPNPRound",
@@ -511,6 +512,45 @@ def __call__(self, x, deg=False, out=None, order="K"):
511512
return res
512513

513514

515+
class DPNPFix(DPNPUnaryFunc):
516+
"""Class that implements dpnp.real unary element-wise functions."""
517+
518+
def __init__(
519+
self,
520+
name,
521+
result_type_resolver_fn,
522+
unary_dp_impl_fn,
523+
docs,
524+
):
525+
super().__init__(
526+
name,
527+
result_type_resolver_fn,
528+
unary_dp_impl_fn,
529+
docs,
530+
)
531+
532+
def __call__(self, x, out=None, order="K"):
533+
if not dpnp.is_supported_array_type(x):
534+
pass # pass to raise error in main implementation
535+
elif dpnp.issubdtype(x.dtype, dpnp.inexact):
536+
pass # for inexact types, pass to calculate in the backend
537+
elif out is not None and (
538+
not dpnp.is_supported_array_type(out) or out.dtype != x.dtype
539+
):
540+
pass # pass to raise error in main implementation
541+
else:
542+
# for exact types, return the input
543+
if out is None:
544+
return dpnp.asarray(x, copy=True)
545+
546+
if isinstance(out, dpt.usm_ndarray):
547+
out = dpnp_array._create_from_usm_ndarray(out)
548+
out[...] = x
549+
return out
550+
551+
return super().__call__(x, out=out, order=order)
552+
553+
514554
class DPNPI0(DPNPUnaryFunc):
515555
"""Class that implements dpnp.i0 unary element-wise functions."""
516556

dpnp/dpnp_iface_mathematical.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
DPNPI0,
6464
DPNPAngle,
6565
DPNPBinaryFunc,
66+
DPNPFix,
6667
DPNPImag,
6768
DPNPReal,
6869
DPNPRound,
@@ -1755,12 +1756,6 @@ def ediff1d(ary, to_end=None, to_begin=None):
17551756
Otherwise the result is stored there and the return value `out` is
17561757
a reference to that array.
17571758
1758-
Limitations
1759-
-----------
1760-
Parameters `where` and `subok` are supported with their default values.
1761-
Keyword argument `kwargs` is currently unsupported.
1762-
Otherwise ``NotImplementedError`` exception will be raised.
1763-
17641759
See Also
17651760
--------
17661761
:obj:`dpnp.round` : Round to given number of decimals.
@@ -1781,7 +1776,7 @@ def ediff1d(ary, to_end=None, to_begin=None):
17811776
array([ 2., 2., -2., -2.])
17821777
"""
17831778

1784-
fix = DPNPUnaryFunc(
1779+
fix = DPNPFix(
17851780
"fix",
17861781
ufi._fix_result_type,
17871782
ufi._fix,
@@ -1933,8 +1928,10 @@ def ediff1d(ary, to_end=None, to_begin=None):
19331928
19341929
Notes
19351930
-----
1936-
Some spreadsheet programs calculate the "floor-towards-zero", in other words floor(-2.5) == -2.
1937-
DPNP instead uses the definition of floor where floor(-2.5) == -3.
1931+
Some spreadsheet programs calculate the "floor-towards-zero", where
1932+
``floor(-2.5) == -2``. DPNP instead uses the definition of :obj:`dpnp.floor`
1933+
where ``floor(-2.5) == -3``. The "floor-towards-zero" function is called
1934+
:obj:`dpnp.fix` in DPNP.
19381935
19391936
Examples
19401937
--------

dpnp/tests/test_mathematical.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,6 +2231,7 @@ def test_complex(self, func, xp, dt):
22312231
)
22322232
def test_out(self, func, dt):
22332233
a = generate_random_numpy_array(10, dt)
2234+
# TODO: use dt_out = dt when dpctl#2030 is fixed
22342235
dt_out = numpy.int8 if dt == dpnp.bool else dt
22352236
out = numpy.empty(a.shape, dtype=dt_out)
22362237
ia, iout = dpnp.array(a), dpnp.array(out)

0 commit comments

Comments
 (0)