diff --git a/dpnp/backend/extensions/ufunc/CMakeLists.txt b/dpnp/backend/extensions/ufunc/CMakeLists.txt index f02030935e1e..cb85bd6213ed 100644 --- a/dpnp/backend/extensions/ufunc/CMakeLists.txt +++ b/dpnp/backend/extensions/ufunc/CMakeLists.txt @@ -34,6 +34,7 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmod.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp index 4ed4cf1f5632..57a326174ea9 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp @@ -34,6 +34,7 @@ #include "fmod.hpp" #include "gcd.hpp" #include "heaviside.hpp" +#include "i0.hpp" #include "lcm.hpp" #include "ldexp.hpp" #include "logaddexp2.hpp" @@ -59,6 +60,7 @@ void init_elementwise_functions(py::module_ m) init_fmod(m); init_gcd(m); init_heaviside(m); + init_i0(m); init_lcm(m); init_ldexp(m); init_logaddexp2(m); diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/i0.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/i0.cpp new file mode 100644 index 000000000000..dac4abbb29fb --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/i0.cpp @@ -0,0 +1,124 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include "dpctl4pybind11.hpp" + +#include "i0.hpp" +#include "kernels/elementwise_functions/i0.hpp" +#include "populate.hpp" + +// include a local copy of elementwise common header from dpctl tensor: +// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp +// TODO: replace by including dpctl header once available +#include "../../elementwise_functions/elementwise_functions.hpp" + +// dpctl tensor headers +#include "kernels/elementwise_functions/common.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpnp::extensions::ufunc +{ +namespace py = pybind11; +namespace py_int = dpnp::extensions::py_internal; + +namespace impl +{ +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +namespace td_ns = dpctl::tensor::type_dispatch; + +/** + * @brief A factory to define pairs of supported types for which + * sycl::i0 function is available. + * + * @tparam T Type of input vector `a` and of result vector `y`. + */ +template +struct OutputType +{ + using value_type = + typename std::disjunction, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; +}; + +using dpnp::kernels::i0::I0Functor; + +template +using ContigFunctor = ew_cmn_ns::UnaryContigFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using StridedFunctor = ew_cmn_ns:: + UnaryStridedFunctor>; + +using ew_cmn_ns::unary_contig_impl_fn_ptr_t; +using ew_cmn_ns::unary_strided_impl_fn_ptr_t; + +static unary_contig_impl_fn_ptr_t i0_contig_dispatch_vector[td_ns::num_types]; +static int i0_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t i0_strided_dispatch_vector[td_ns::num_types]; + +MACRO_POPULATE_DISPATCH_VECTORS(i0); +} // namespace impl + +void init_i0(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_i0_dispatch_vectors(); + using impl::i0_contig_dispatch_vector; + using impl::i0_output_typeid_vector; + using impl::i0_strided_dispatch_vector; + + auto i0_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_int::py_unary_ufunc( + src, dst, exec_q, depends, i0_output_typeid_vector, + i0_contig_dispatch_vector, i0_strided_dispatch_vector); + }; + m.def("_i0", i0_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto i0_result_type_pyapi = [&](const py::dtype &dtype) { + return py_int::py_unary_ufunc_result_type(dtype, + i0_output_typeid_vector); + }; + m.def("_i0_result_type", i0_result_type_pyapi); + } +} +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/i0.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/i0.hpp new file mode 100644 index 000000000000..fad40f84077e --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/i0.hpp @@ -0,0 +1,35 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpnp::extensions::ufunc +{ +void init_i0(py::module_ m); +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/kernels/elementwise_functions/i0.hpp b/dpnp/backend/kernels/elementwise_functions/i0.hpp new file mode 100644 index 000000000000..c00629c9df62 --- /dev/null +++ b/dpnp/backend/kernels/elementwise_functions/i0.hpp @@ -0,0 +1,270 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +/** + * Version of SYCL DPC++ 2025.1 compiler where an issue with + * sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved. + */ +#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT +#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241111L +#endif + +#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT +#include +#endif + +namespace dpnp::kernels::i0 +{ +/** + * The below implementation of Bessel function of order 0 + * is based on the source code from https://github.com/gcc-mirror/gcc + */ +namespace impl +{ +/** + * @brief This routine returns the cylindrical Bessel functions + * of order 0 by series expansion. + * + * @param x The argument of the Bessel function. + * @return The output Bessel function. + */ +template +inline Tp cyl_bessel_ij_0_series(const Tp x, const unsigned int max_iter) +{ + const Tp x2 = x / Tp(2); + const Tp fact = sycl::exp(-sycl::lgamma(Tp(1))); + + const Tp xx4 = x2 * x2; + Tp Jn = Tp(1); + Tp term = Tp(1); + constexpr Tp eps = std::numeric_limits::epsilon(); + + for (unsigned int i = 1; i < max_iter; ++i) { + term *= xx4 / (Tp(i) * Tp(i)); + Jn += term; + if (sycl::fabs(term / Jn) < eps) { + break; + } + } + return fact * Jn; +} + +/** + * @brief Compute the modified Bessel functions. + * + * @param x The argument of the Bessel functions. + * @return The output Bessel function. + */ +template +inline Tp bessel_ik_0(Tp x) +{ + constexpr Tp eps = std::numeric_limits::epsilon(); + constexpr Tp fp_min = Tp(10) * eps; + constexpr int max_iter = 15000; + constexpr Tp x_min = Tp(2); + + const Tp mu = Tp(0); + const Tp mu2 = mu * mu; + const Tp xi = Tp(1) / x; + const Tp xi2 = Tp(2) * xi; + Tp h = fp_min; + + Tp b = Tp(0); + Tp d = Tp(0); + Tp c = h; + int i; + for (i = 1; i <= max_iter; ++i) { + b += xi2; + d = Tp(1) / (b + d); + c = b + Tp(1) / c; + + const Tp del = c * d; + h *= del; + if (sycl::fabs(del - Tp(1)) < eps) { + break; + } + } + if (i > max_iter) { + // argument `x` is too large + return std::numeric_limits::infinity(); + } + + Tp Inul = fp_min; + const Tp Inul1 = Inul; + const Tp Ipnul = h * Inul; + + constexpr Tp pi = static_cast(3.1415926535897932384626433832795029L); + Tp f = Ipnul / Inul; + Tp Kmu, Knu1; + if (x < x_min) { + const Tp x2 = x / Tp(2); + const Tp pimu = pi * mu; + const Tp fact = + (sycl::fabs(pimu) < eps ? Tp(1) : pimu / sycl::sin(pimu)); + + Tp d = -sycl::log(x2); + Tp e = mu * d; + const Tp fact2 = (sycl::fabs(e) < eps ? Tp(1) : sycl::sinh(e) / e); + + // compute the gamma functions required by the Temme series expansions + constexpr Tp gam1 = + -static_cast(0.5772156649015328606065120900824024L); + const Tp gam2 = Tp(1) / sycl::tgamma(Tp(1)); + + Tp ff = fact * (gam1 * sycl::cosh(e) + gam2 * fact2 * d); + Tp sum = ff; + e = sycl::exp(e); + + Tp p = e / (Tp(2) * gam2); + Tp q = Tp(1) / (Tp(2) * e * gam2); + Tp c = Tp(1); + d = x2 * x2; + Tp sum1 = p; + int i; + for (i = 1; i <= max_iter; ++i) { + ff = (i * ff + p + q) / (i * i - mu2); + c *= d / i; + p /= i - mu; + q /= i + mu; + const Tp del = c * ff; + sum += del; + const Tp __del1 = c * (p - i * ff); + sum1 += __del1; + if (sycl::fabs(del) < eps * sycl::fabs(sum)) { + break; + } + } + if (i > max_iter) { + // Bessel k series failed to converge + return std::numeric_limits::quiet_NaN(); + } + Kmu = sum; + Knu1 = sum1 * xi2; + } + else { + Tp b = Tp(2) * (Tp(1) + x); + Tp d = Tp(1) / b; + Tp delh = d; + Tp h = delh; + Tp q1 = Tp(0); + Tp q2 = Tp(1); + Tp a1 = Tp(0.25L) - mu2; + Tp q = c = a1; + Tp a = -a1; + Tp s = Tp(1) + q * delh; + int i; + for (i = 2; i <= max_iter; ++i) { + a -= 2 * (i - 1); + c = -a * c / i; + const Tp qnew = (q1 - b * q2) / a; + q1 = q2; + q2 = qnew; + q += c * qnew; + b += Tp(2); + d = Tp(1) / (b + a * d); + delh = (b * d - Tp(1)) * delh; + h += delh; + const Tp dels = q * delh; + s += dels; + if (sycl::fabs(dels / s) < eps) { + break; + } + } + if (i > max_iter) { + // Steed's method failed + return std::numeric_limits::quiet_NaN(); + } + h = a1 * h; + Kmu = sycl::sqrt(pi / (Tp(2) * x)) * sycl::exp(-x) / s; + Knu1 = Kmu * (mu + x + Tp(0.5L) - h) * xi; + } + + Tp Kpmu = mu * xi * Kmu - Knu1; + Tp Inumu = xi / (f * Kmu - Kpmu); + return Inumu * Inul1 / Inul; +} + +/** + * @brief Return the regular modified Bessel function of order 0. + * + * @param x The argument of the regular modified Bessel function. + * @return The output regular modified Bessel function. + */ +template +inline Tp cyl_bessel_i0(Tp x) +{ + if (sycl::isnan(x)) { + return std::numeric_limits::quiet_NaN(); + } + + if (sycl::isinf(x)) { + // return +inf per any input infinity + return std::numeric_limits::infinity(); + } + + if (x == Tp(0)) { + return Tp(1); + } + + if (x * x < Tp(10)) { + return cyl_bessel_ij_0_series(x, 200); + } + return bessel_ik_0(sycl::fabs(x)); +} +} // namespace impl + +template +struct I0Functor +{ + // is function constant for given argT + using is_constant = typename std::false_type; + // constant value, if constant + // constexpr resT constant_value = resT{}; + // is function defined for sycl::vec + using supports_vec = typename std::false_type; + // do both argT and resT support subgroup store/load operation + using supports_sg_loadstore = typename std::true_type; + + resT operator()(const argT &x) const + { +#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT + using sycl::ext::intel::math::cyl_bessel_i0; +#else + using impl::cyl_bessel_i0; +#endif + + if constexpr (std::is_same_v) { + return static_cast(cyl_bessel_i0(float(x))); + } + else { + return cyl_bessel_i0(x); + } + } +}; +} // namespace dpnp::kernels::i0 diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 7c627867b694..6cb603fb4804 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -38,11 +38,7 @@ from dpnp.dpnp_array import dpnp_array __all__ = [ - "acceptance_fn_gcd_lcm", - "acceptance_fn_negative", - "acceptance_fn_positive", - "acceptance_fn_sign", - "acceptance_fn_subtract", + "DPNPI0", "DPNPAngle", "DPNPBinaryFunc", "DPNPImag", @@ -50,6 +46,11 @@ "DPNPRound", "DPNPSinc", "DPNPUnaryFunc", + "acceptance_fn_gcd_lcm", + "acceptance_fn_negative", + "acceptance_fn_positive", + "acceptance_fn_sign", + "acceptance_fn_subtract", "resolve_weak_types_2nd_arg_int", ] @@ -500,6 +501,27 @@ def __call__(self, x, deg=False, out=None, order="K"): return res +class DPNPI0(DPNPUnaryFunc): + """Class that implements dpnp.i0 unary element-wise functions.""" + + def __init__( + self, + name, + result_type_resolver_fn, + unary_dp_impl_fn, + docs, + ): + super().__init__( + name, + result_type_resolver_fn, + unary_dp_impl_fn, + docs, + ) + + def __call__(self, x, out=None, order="K"): + return super().__call__(x, out=out, order=order) + + class DPNPImag(DPNPUnaryFunc): """Class that implements dpnp.imag unary element-wise functions.""" diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index d49712added5..eedf2b5d6d95 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -60,6 +60,7 @@ from .dpnp_algo import dpnp_modf from .dpnp_algo.dpnp_elementwise_common import ( + DPNPI0, DPNPAngle, DPNPBinaryFunc, DPNPImag, @@ -111,6 +112,7 @@ "gradient", "heaviside", "imag", + "i0", "lcm", "ldexp", "maximum", @@ -2519,6 +2521,50 @@ def gradient(f, *varargs, axis=None, edge_order=1): ) +_I0_DOCSTRING = """ +Modified Bessel function of the first kind, order 0. + +Usually denoted :math:`I_0`. + +For full documentation refer to :obj:`numpy.i0`. + +Parameters +---------- +x : {dpnp.ndarray, usm_ndarray} + Argument of the Bessel function, expected to have floating-point data type. +out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. + Default: ``None``. +order : {"C", "F", "A", "K"}, optional + Memory layout of the newly output array, if parameter `out` is ``None``. + Default: ``"K"``. + +Returns +------- +out : dpnp.ndarray + The modified Bessel function evaluated at each of the elements of `x`. + If the input is of either boolean or integer data type, the returned array + will have the default floating point data type of a device where `x` has + been allocated. Otherwise the returned array has the same data type. + +Examples +-------- +>>> import dpnp as np +>>> np.i0(np.array(0.0)) +array(1.) +>>> np.i0(np.array([0, 1, 2, 3])) +array([1. , 1.26606588, 2.2795853 , 4.88079259]) +""" + +i0 = DPNPI0( + "i0", + ufi._i0_result_type, + ufi._i0, + _I0_DOCSTRING, +) + + _IMAG_DOCSTRING = """ Computes imaginary part of each element `x_i` for input array `x`. diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 422c5fab8d15..20db6a554f7c 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1355,6 +1355,73 @@ def test_op_multiple_dtypes(dtype1, func, dtype2, data): assert_allclose(result, expected) +class TestI0: + def test_0d(self): + a = dpnp.array(0.5) + na = a.asnumpy() + assert_dtype_allclose(dpnp.i0(a), numpy.i0(na)) + + @pytest.mark.parametrize( + "dt", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) + ) + def test_1d(self, dt): + a = numpy.array( + [0.49842636, 0.6969809, 0.22011976, 0.0155549, 10.0], dtype=dt + ) + ia = dpnp.array(a) + + result = dpnp.i0(ia) + expected = numpy.i0(a) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dt", get_float_dtypes()) + def test_2d(self, dt): + a = numpy.array( + [ + [0.827002, 0.99959078], + [0.89694769, 0.39298162], + [0.37954418, 0.05206293], + [0.36465447, 0.72446427], + [0.48164949, 0.50324519], + ], + dtype=dt, + ) + ia = dpnp.array(a) + + result = dpnp.i0(ia) + expected = numpy.i0(a) + assert_dtype_allclose(result, expected) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nan(self): + a = numpy.array(numpy.nan) + ia = dpnp.array(a) + + result = dpnp.i0(ia) + expected = numpy.i0(a) + assert_equal(result, expected) + + # numpy.i0(numpy.inf) returns NaN, but expected Inf + @pytest.mark.parametrize("dt", get_float_dtypes()) + def test_infs(self, dt): + a = dpnp.array([dpnp.inf, -dpnp.inf], dtype=dt) + assert (dpnp.i0(a) == dpnp.inf).all() + + # dpnp.i0 returns float16, but numpy.i0 returns float64 + def test_bool(self): + a = numpy.array([False, True, False]) + ia = dpnp.array(a) + + result = dpnp.i0(ia) + expected = numpy.i0(a) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("xp", [dpnp, numpy]) + def test_complex(self, xp): + a = xp.array([0, 1 + 2j]) + assert_raises((ValueError, TypeError), xp.i0, a) + + class TestLdexp: @pytest.mark.parametrize("mant_dt", get_float_dtypes()) @pytest.mark.parametrize("exp_dt", get_integer_dtypes()) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 5e638e1968af..979b6f0d7951 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -479,6 +479,7 @@ def test_meshgrid(device): pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("gradient", [1.0, 2.0, 4.0, 7.0, 11.0, 16.0]), pytest.param("histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5]), + pytest.param("i0", [0, 1, 2, 3]), pytest.param( "imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] ), diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index ef11623d5f3c..07b8bb4ef174 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -600,6 +600,7 @@ def test_norm(usm_type, ord, axis): pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("gradient", [1, 2, 4, 7, 11, 16]), pytest.param("histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5]), + pytest.param("i0", [0, 1, 2, 3]), pytest.param( "imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] ), diff --git a/tests/third_party/cupy/math_tests/test_special.py b/tests/third_party/cupy/math_tests/test_special.py index 5b3b84c29a5e..7699b96fa567 100644 --- a/tests/third_party/cupy/math_tests/test_special.py +++ b/tests/third_party/cupy/math_tests/test_special.py @@ -7,8 +7,6 @@ class TestSpecial(unittest.TestCase): - - @pytest.mark.skip("i0 is not implemented") @testing.for_dtypes(["e", "f", "d"]) @testing.numpy_cupy_allclose(rtol=1e-3) def test_i0(self, xp, dtype):