diff --git a/dpnp/backend/extensions/ufunc/CMakeLists.txt b/dpnp/backend/extensions/ufunc/CMakeLists.txt index e14e053f369f..4fe714dbeb5b 100644 --- a/dpnp/backend/extensions/ufunc/CMakeLists.txt +++ b/dpnp/backend/extensions/ufunc/CMakeLists.txt @@ -35,6 +35,7 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp ) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp index 1b348b6dc0e6..c5b931d3d21d 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp @@ -35,6 +35,7 @@ #include "gcd.hpp" #include "heaviside.hpp" #include "lcm.hpp" +#include "ldexp.hpp" #include "logaddexp2.hpp" #include "radians.hpp" @@ -57,6 +58,7 @@ void init_elementwise_functions(py::module_ m) init_gcd(m); init_heaviside(m); init_lcm(m); + init_ldexp(m); init_logaddexp2(m); init_radians(m); } diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/ldexp.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/ldexp.cpp new file mode 100644 index 000000000000..c9cdd7edb5e5 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/ldexp.cpp @@ -0,0 +1,169 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// maxification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include "dpctl4pybind11.hpp" + +#include "kernels/elementwise_functions/ldexp.hpp" +#include "ldexp.hpp" +#include "populate.hpp" + +// include a local copy of elementwise common header from dpctl tensor: +// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp +// TODO: replace by including dpctl header once available +#include "../../elementwise_functions/elementwise_functions.hpp" + +// dpctl tensor headers +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/maximum.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpnp::extensions::ufunc +{ +namespace py = pybind11; +namespace py_int = dpnp::extensions::py_internal; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +namespace max_ns = dpctl::tensor::kernels::maximum; + +// Supports the same types table as for maximum function in dpctl +// template +// using OutputType = max_ns::MaximumOutputType; +template +struct OutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; +}; + +using dpnp::kernels::ldexp::LdexpFunctor; + +template +using ContigFunctor = + ew_cmn_ns::BinaryContigFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using StridedFunctor = + ew_cmn_ns::BinaryStridedFunctor>; + +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +static binary_contig_impl_fn_ptr_t + ldexp_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int ldexp_output_typeid_table[td_ns::num_types][td_ns::num_types]; +static binary_strided_impl_fn_ptr_t + ldexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +MACRO_POPULATE_DISPATCH_TABLES(ldexp); +} // namespace impl + +void init_ldexp(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_ldexp_dispatch_tables(); + using impl::ldexp_contig_dispatch_table; + using impl::ldexp_output_typeid_table; + using impl::ldexp_strided_dispatch_table; + + auto ldexp_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_int::py_binary_ufunc( + src1, src2, dst, exec_q, depends, ldexp_output_typeid_table, + ldexp_contig_dispatch_table, ldexp_strided_dispatch_table, + // no support of C-contig row with broadcasting in OneMKL + td_ns::NullPtrTable< + impl:: + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + td_ns::NullPtrTable< + impl:: + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_ldexp", ldexp_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto ldexp_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_int::py_binary_ufunc_result_type( + dtype1, dtype2, ldexp_output_typeid_table); + }; + m.def("_ldexp_result_type", ldexp_result_type_pyapi); + } +} +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/ldexp.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/ldexp.hpp new file mode 100644 index 000000000000..1ab4ff10f87d --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/ldexp.hpp @@ -0,0 +1,35 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpnp::extensions::ufunc +{ +void init_ldexp(py::module_ m); +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/kernels/elementwise_functions/ldexp.hpp b/dpnp/backend/kernels/elementwise_functions/ldexp.hpp new file mode 100644 index 000000000000..1755d729f29a --- /dev/null +++ b/dpnp/backend/kernels/elementwise_functions/ldexp.hpp @@ -0,0 +1,55 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +// dpctl tensor headers +#include "utils/math_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpnp::kernels::ldexp +{ +template +struct LdexpFunctor +{ + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::false_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + if (((int)in2) == in2) { + return sycl::ldexp(in1, in2); + } + + // a separate handling for large integer values + if (in2 > 0) { + return std::numeric_limits::infinity(); + } + return resT(0); + } +}; +} // namespace dpnp::kernels::ldexp diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 0793306a1b44..0f2b383751ff 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -25,6 +25,8 @@ # ***************************************************************************** import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as dti +import dpctl.tensor._type_utils as dtu import numpy from dpctl.tensor._elementwise_common import ( BinaryElementwiseFunc, @@ -43,9 +45,11 @@ "acceptance_fn_subtract", "DPNPAngle", "DPNPBinaryFunc", + "DPNPImag", "DPNPReal", "DPNPRound", "DPNPUnaryFunc", + "resolve_weak_types_2nd_arg_int", ] @@ -244,6 +248,14 @@ class DPNPBinaryFunc(BinaryElementwiseFunc): The function is only called when both arguments of the binary function require casting, e.g. both arguments of `dpctl.tensor.logaddexp` are arrays with integral data type. + weak_type_resolver : {callable}, optional + Function to influence type promotion behavior for Python scalar types + of this binary function. The function takes 3 arguments: + o1_dtype - Data type or Python scalar type of the first argument + o2_dtype - Data type or Python scalar type of of the second argument + sycl_dev - The :class:`dpctl.SyclDevice` where the function + evaluation is carried out. + One of `o1_dtype` and `o2_dtype` must be a ``dtype`` instance. """ def __init__( @@ -256,6 +268,7 @@ def __init__( mkl_impl_fn=None, binary_inplace_fn=None, acceptance_fn=None, + weak_type_resolver=None, ): def _call_func(src1, src2, dst, sycl_queue, depends=None): """ @@ -281,6 +294,7 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None): docs, binary_inplace_fn, acceptance_fn=acceptance_fn, + weak_type_resolver=weak_type_resolver, ) self.__name__ = "DPNPBinaryFunc" @@ -478,13 +492,34 @@ def __init__( docs, ) - def __call__(self, x, deg=False): - res = super().__call__(x) + def __call__(self, x, deg=False, out=None, order="K"): + res = super().__call__(x, out=out, order=order) if deg is True: res *= 180 / dpnp.pi return res +class DPNPImag(DPNPUnaryFunc): + """Class that implements dpnp.imag unary element-wise functions.""" + + def __init__( + self, + name, + result_type_resolver_fn, + unary_dp_impl_fn, + docs, + ): + super().__init__( + name, + result_type_resolver_fn, + unary_dp_impl_fn, + docs, + ) + + def __call__(self, x, out=None, order="K"): + return super().__call__(x, out=out, order=order) + + class DPNPReal(DPNPUnaryFunc): """Class that implements dpnp.real unary element-wise functions.""" @@ -502,9 +537,9 @@ def __init__( docs, ) - def __call__(self, x): + def __call__(self, x, out=None, order="K"): if numpy.iscomplexobj(x): - return super().__call__(x) + return super().__call__(x, out=out, order=order) return x @@ -606,3 +641,24 @@ def acceptance_fn_subtract( ) else: return True + + +def resolve_weak_types_2nd_arg_int(o1_dtype, o2_dtype, sycl_dev): + """ + The second weak dtype has to be upcasting up to default integer dtype + for a SYCL device where it is possible. + For other cases the default weak types resolving will be applied. + + """ + + if dtu._is_weak_dtype(o2_dtype): + o1_kind_num = dtu._strong_dtype_num_kind(o1_dtype) + o2_kind_num = dtu._weak_type_num_kind(o2_dtype) + if o2_kind_num < o1_kind_num: + if isinstance( + o2_dtype, (dtu.WeakBooleanType, dtu.WeakIntegralType) + ): + return o1_dtype, dpt.dtype( + dti.default_device_int_type(sycl_dev) + ) + return dtu._resolve_weak_types(o1_dtype, o2_dtype, sycl_dev) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 8198f24bb3cc..4a6bd9edc8ef 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -62,6 +62,7 @@ from .dpnp_algo.dpnp_elementwise_common import ( DPNPAngle, DPNPBinaryFunc, + DPNPImag, DPNPReal, DPNPRound, DPNPUnaryFunc, @@ -70,6 +71,7 @@ acceptance_fn_positive, acceptance_fn_sign, acceptance_fn_subtract, + resolve_weak_types_2nd_arg_int, ) from .dpnp_array import dpnp_array from .dpnp_utils import call_origin, get_usm_allocations @@ -107,6 +109,7 @@ "heaviside", "imag", "lcm", + "ldexp", "maximum", "minimum", "mod", @@ -488,6 +491,9 @@ def _process_ediff1d_args(arg, arg_name, ary_dtype, ary_sycl_queue, usm_type): ---------- x : {dpnp.ndarray, usm_ndarray} Input array, expected to have a complex-valued floating-point data type. +deg : bool, optional + Return angle in degrees if ``True``, radians if ``False``. + Default: ``False``. out : {None, dpnp.ndarray, usm_ndarray}, optional Output array to populate. Array must have the correct shape and the expected data type. @@ -1450,6 +1456,12 @@ def ediff1d(ary, to_end=None, to_begin=None): will have a data type that depends on the capabilities of the device on which the array resides. +Limitations +----------- +Parameters `where` and `subok` are supported with their default values. +Keyword argument `kwargs` is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + See Also -------- :obj:`dpnp.absolute` : Absolute values including `complex` types. @@ -1502,6 +1514,12 @@ def ediff1d(ary, to_end=None, to_begin=None): Otherwise the result is stored there and the return value `out` is a reference to that array. +Limitations +----------- +Parameters `where` and `subok` are supported with their default values. +Keyword argument `kwargs` is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + See Also -------- :obj:`dpnp.round` : Round to given number of decimals. @@ -2029,6 +2047,12 @@ def ediff1d(ary, to_end=None, to_begin=None): out : dpnp.ndarray The greatest common divisor of the absolute value of the inputs. +Limitations +----------- +Parameters `where` and `subok` are supported with their default values. +Keyword argument `kwargs` is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + See Also -------- :obj:`dpnp.lcm` : The lowest common multiple. @@ -2359,7 +2383,7 @@ def gradient(f, *varargs, axis=None, edge_order=1): array(1.) """ -imag = DPNPUnaryFunc( +imag = DPNPImag( "imag", ti._imag_result_type, ti._imag, @@ -2393,6 +2417,12 @@ def gradient(f, *varargs, axis=None, edge_order=1): out : dpnp.ndarray The lowest common multiple of the absolute value of the inputs. +Limitations +----------- +Parameters `where` and `subok` are supported with their default values. +Keyword argument `kwargs` is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + See Also -------- :obj:`dpnp.gcd` : The greatest common divisor. @@ -2415,6 +2445,68 @@ def gradient(f, *varargs, axis=None, edge_order=1): ) +_LDEXP_DOCSTRING = """ +Returns x1 * 2**x2, element-wise. + +The mantissas `x1` and exponents of two `x2` are used to construct floating point +numbers ``x1 * 2**x2``. + +For full documentation refer to :obj:`numpy.ldexp`. + +Parameters +---------- +x1 : {dpnp.ndarray, usm_ndarray, scalar} + Array of multipliers, expected to have floating-point data types. + Both inputs `x1` and `x2` can not be scalars at the same time. +x2 : {dpnp.ndarray, usm_ndarray, scalar} + Array of exponents of two, expected to have an integer data type. + Both inputs `x1` and `x2` can not be scalars at the same time. +out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. Array must have the correct shape and + the expected data type. + Default: ``None``. +order : {"C", "F", "A", "K"}, optional + Memory layout of the newly output array, if parameter `out` is ``None``. + Default: ``"K"``. + +Returns +------- +out : dpnp.ndarray + The result of ``x1 * 2**x2``. + +Limitations +----------- +Parameters `where` and `subok` are supported with their default values. +Keyword argument `kwargs` is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + +See Also +-------- +:obj:`dpnp.frexp` : Return (y1, y2) from ``x = y1 * 2**y2``, inverse to :obj:`dpnp.ldexp`. + +Notes +----- +Complex dtypes are not supported, they will raise a ``TypeError``. + +:obj:`dpnp.ldexp` is useful as the inverse of :obj:`dpnp.frexp`, if used by +itself it is more clear to simply use the expression ``x1 * 2**x2``. + +Examples +-------- +>>> import dpnp as np +>>> np.ldexp(5, np.arange(4)) +array([ 5., 10., 20., 40.]) +""" + +ldexp = DPNPBinaryFunc( + "_ldexp", + ufi._ldexp_result_type, + ufi._ldexp, + _LDEXP_DOCSTRING, + weak_type_resolver=resolve_weak_types_2nd_arg_int, +) + + _MAXIMUM_DOCSTRING = """ Compares two input arrays `x1` and `x2` and returns a new array containing the element-wise maxima. @@ -3201,7 +3293,7 @@ def prod( Limitations ----------- -Parameters `where' and `subok` are supported with their default values. +Parameters `where` and `subok` are supported with their default values. Keyword argument `kwargs` is currently unsupported. Otherwise ``NotImplementedError`` exception will be raised. @@ -3237,6 +3329,13 @@ def prod( ---------- x : {dpnp.ndarray, usm_ndarray} Input array, expected to have numeric data type. +out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. + Default: ``None``. +order : {"C", "F", "A", "K"}, optional + Memory layout of the newly output array, if parameter `out` is ``None``. + Default: ``"K"``. Returns ------- @@ -3448,6 +3547,7 @@ def real_if_close(a, tol=100): Limitations ----------- +Parameters `where` and `subok` are supported with their default values. Keyword argument `kwargs` is currently unsupported. Otherwise ``NotImplementedError`` exception will be raised. diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 32b50caf7ebe..1248c32c9acf 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -17,14 +17,6 @@ tests/test_umath.py::test_umaths[('divmod', 'ff')] tests/test_umath.py::test_umaths[('divmod', 'dd')] tests/test_umath.py::test_umaths[('frexp', 'f')] tests/test_umath.py::test_umaths[('frexp', 'd')] -tests/test_umath.py::test_umaths[('gcd', 'ii')] -tests/test_umath.py::test_umaths[('gcd', 'll')] -tests/test_umath.py::test_umaths[('lcm', 'ii')] -tests/test_umath.py::test_umaths[('lcm', 'll')] -tests/test_umath.py::test_umaths[('ldexp', 'fi')] -tests/test_umath.py::test_umaths[('ldexp', 'fl')] -tests/test_umath.py::test_umaths[('ldexp', 'di')] -tests/test_umath.py::test_umaths[('ldexp', 'dl')] tests/test_umath.py::test_umaths[('spacing', 'f')] tests/test_umath.py::test_umaths[('spacing', 'd')] diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 3062c7f6498b..2ab14190fa78 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -24,14 +24,6 @@ tests/test_umath.py::test_umaths[('divmod', 'dd')] tests/test_umath.py::test_umaths[('floor_divide', 'ff')] tests/test_umath.py::test_umaths[('frexp', 'f')] tests/test_umath.py::test_umaths[('frexp', 'd')] -tests/test_umath.py::test_umaths[('gcd', 'ii')] -tests/test_umath.py::test_umaths[('gcd', 'll')] -tests/test_umath.py::test_umaths[('lcm', 'ii')] -tests/test_umath.py::test_umaths[('lcm', 'll')] -tests/test_umath.py::test_umaths[('ldexp', 'fi')] -tests/test_umath.py::test_umaths[('ldexp', 'fl')] -tests/test_umath.py::test_umaths[('ldexp', 'di')] -tests/test_umath.py::test_umaths[('ldexp', 'dl')] tests/test_umath.py::test_umaths[('spacing', 'f')] tests/test_umath.py::test_umaths[('spacing', 'd')] diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 2df4b848f0d9..d615f827c28d 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1289,6 +1289,86 @@ def test_op_multiple_dtypes(dtype1, func, dtype2, data): assert_allclose(result, expected) +class TestLdexp: + @pytest.mark.parametrize("mant_dt", get_float_dtypes()) + @pytest.mark.parametrize("exp_dt", get_integer_dtypes()) + def test_basic(self, mant_dt, exp_dt): + if ( + numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0" + and exp_dt == numpy.int64 + and numpy.dtype("l") != numpy.int64 + ): + pytest.skip("numpy.ldexp doesn't have a loop for the input types") + + mant = numpy.array(2.0, dtype=mant_dt) + exp = numpy.array(3, dtype=exp_dt) + imant, iexp = dpnp.array(mant), dpnp.array(exp) + + result = dpnp.ldexp(imant, iexp) + expected = numpy.ldexp(mant, exp) + assert_almost_equal(result, expected) + + def test_float_scalar(self): + a = numpy.array(3) + ia = dpnp.array(a) + + result = dpnp.ldexp(2.0, ia) + expected = numpy.ldexp(2.0, a) + assert_almost_equal(result, expected) + + @pytest.mark.parametrize("max_min", ["max", "min"]) + def test_overflow(self, max_min): + exp_val = getattr(numpy.iinfo(numpy.dtype("l")), max_min) + + result = dpnp.ldexp(dpnp.array(2.0), exp_val) + with numpy.errstate(over="ignore"): + # we can't use here numpy.array(2.0), because NumPy 2.0 will cast + # `exp_val` to int32 dtype then and `OverflowError` will be raised + expected = numpy.ldexp(2.0, exp_val) + assert_equal(result, expected) + + @pytest.mark.parametrize("val", [numpy.nan, numpy.inf, -numpy.inf]) + def test_nan_int_mant(self, val): + mant = numpy.array(val) + imant = dpnp.array(mant) + + result = dpnp.ldexp(imant, 5) + expected = numpy.ldexp(mant, 5) + assert_equal(result, expected) + + def test_zero_exp(self): + exp = numpy.array(0) + iexp = dpnp.array(exp) + + result = dpnp.ldexp(-2.5, iexp) + expected = numpy.ldexp(-2.5, exp) + assert_equal(result, expected) + + @pytest.mark.parametrize("stride", [-4, -2, -1, 1, 2, 4]) + @pytest.mark.parametrize("dt", get_float_dtypes()) + def test_strides(self, stride, dt): + mant = numpy.array( + [0.125, 0.25, 0.5, 1.0, 1.0, 2.0, 4.0, 8.0], dtype=dt + ) + exp = numpy.array([3, 2, 1, 0, 0, -1, -2, -3], dtype="i") + out = numpy.zeros(8, dtype=dt) + imant, iexp, iout = dpnp.array(mant), dpnp.array(exp), dpnp.array(out) + + result = dpnp.ldexp(imant[::stride], iexp[::stride], out=iout[::stride]) + expected = numpy.ldexp(mant[::stride], exp[::stride], out=out[::stride]) + assert_equal(result, expected) + + def test_bool_exp(self): + result = dpnp.ldexp(3.7, dpnp.array(True)) + expected = numpy.ldexp(3.7, numpy.array(True)) + assert_almost_equal(result, expected) + + @pytest.mark.parametrize("xp", [dpnp, numpy]) + def test_uint64_exp(self, xp): + x = xp.array(4, dtype=numpy.uint64) + assert_raises((ValueError, TypeError), xp.ldexp, 7.3, x) + + @pytest.mark.parametrize( "rhs", [[[1, 2, 3], [4, 5, 6]], [2.0, 1.5, 1.0], 3, 0.3] ) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index aa04497f453d..5179f97872d9 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -741,6 +741,11 @@ def test_reduce_hypot(device): [0, 1, 2, 3, 4, 5], [20, 20, 20, 20, 20, 20], ), + pytest.param( + "ldexp", + [5, 5, 5, 5, 5], + [0, 1, 2, 3, 4], + ), pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param("logaddexp2", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param( diff --git a/tests/test_umath.py b/tests/test_umath.py index 067e212a8a90..8feec65ed23c 100644 --- a/tests/test_umath.py +++ b/tests/test_umath.py @@ -99,8 +99,18 @@ def test_umaths(test_cases): args = get_args(args_str, sh, xp=numpy) iargs = get_args(args_str, sh, xp=dpnp) - if umath == "reciprocal" and args[0].dtype in [numpy.int32, numpy.int64]: - pytest.skip("For integer input array, numpy.reciprocal returns zero.") + if umath == "reciprocal": + if args[0].dtype in [numpy.int32, numpy.int64]: + pytest.skip( + "For integer input array, numpy.reciprocal returns zero." + ) + elif umath == "ldexp": + if ( + numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0" + and args[1].dtype == numpy.int64 + and numpy.dtype("l") != numpy.int64 + ): + pytest.skip("numpy.ldexp doesn't have a loop for the input types") # original expected = getattr(numpy, umath)(*args) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index cfa65b9b1cef..8db3c5caab25 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -704,6 +704,11 @@ def test_1in_1out(func, data, usm_type): [0, 1, 2, 3, 4, 5], [20, 20, 20, 20, 20, 20], ), + pytest.param( + "ldexp", + [5, 5, 5, 5, 5], + [0, 1, 2, 3, 4], + ), pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param("logaddexp2", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param("maximum", [0.0, 1.0, 2.0], [3.0, 4.0, 5.0]), diff --git a/tests/third_party/cupy/math_tests/test_floating.py b/tests/third_party/cupy/math_tests/test_floating.py index 9b72537247bb..b9e71391bbd2 100644 --- a/tests/third_party/cupy/math_tests/test_floating.py +++ b/tests/third_party/cupy/math_tests/test_floating.py @@ -29,7 +29,6 @@ def test_copysign_float(self, xp, dtype): b = xp.array([-xp.inf, -3, -0.0, 0, 3, xp.inf], dtype=dtype)[None, :] return xp.copysign(a, b) - @pytest.mark.skip("ldexp() is not implemented yet") @testing.for_float_dtypes(name="ftype") @testing.for_dtypes(["i", "l"], name="itype") @testing.numpy_cupy_array_equal()