diff --git a/.github/workflows/array-api-skips.txt b/.github/workflows/array-api-skips.txt index f9de622d05fd..6c108a165eea 100644 --- a/.github/workflows/array-api-skips.txt +++ b/.github/workflows/array-api-skips.txt @@ -3,11 +3,6 @@ # no 'uint8' dtype array_api_tests/test_array_object.py::test_getitem_masking -# no 'isdtype' function -array_api_tests/test_data_type_functions.py::test_isdtype -array_api_tests/test_has_names.py::test_has_names[data_type-isdtype] -array_api_tests/test_signatures.py::test_func_signature[isdtype] - # missing unique-like functions array_api_tests/test_has_names.py::test_has_names[set-unique_all] array_api_tests/test_has_names.py::test_has_names[set-unique_counts] diff --git a/doc/reference/dtype.rst b/doc/reference/dtype.rst index a6e200662797..7b71fb1cc721 100644 --- a/doc/reference/dtype.rst +++ b/doc/reference/dtype.rst @@ -14,7 +14,6 @@ Data type routines dpnp.min_scalar_type dpnp.result_type dpnp.common_type - dpnp.obj2sctype Creating data types ------------------- diff --git a/dpnp/dpnp_iface_types.py b/dpnp/dpnp_iface_types.py index 3c72af138591..97a90d213b45 100644 --- a/dpnp/dpnp_iface_types.py +++ b/dpnp/dpnp_iface_types.py @@ -64,6 +64,7 @@ "integer", "intc", "intp", + "isdtype", "issubdtype", "is_type_supported", "nan", @@ -194,11 +195,66 @@ def iinfo(dtype): smallest representable number. """ + if isinstance(dtype, dpnp_array): dtype = dtype.dtype return dpt.iinfo(dtype) +def isdtype(dtype, kind): + """ + Returns a boolean indicating whether a provided `dtype` is + of a specified data type `kind`. + + Parameters + ---------- + dtype : dtype + The input dtype. + kind : {dtype, str, tuple of dtypes or strs} + The input dtype or dtype kind. Allowed dtype kinds are: + + * ``'bool'`` : boolean kind + * ``'signed integer'`` : signed integer data types + * ``'unsigned integer'`` : unsigned integer data types + * ``'integral'`` : integer data types + * ``'real floating'`` : real-valued floating-point data types + * ``'complex floating'`` : complex floating-point data types + * ``'numeric'`` : numeric data types + + Returns + ------- + out : bool + A boolean indicating whether a provided `dtype` is of a specified data + type `kind`. + + See Also + -------- + :obj:`dpnp.issubdtype` : Test if the first argument is a type code + lower/equal in type hierarchy. + + Examples + -------- + >>> import dpnp as np + >>> np.isdtype(np.float32, np.float64) + False + >>> np.isdtype(np.float32, "real floating") + True + >>> np.isdtype(np.complex128, ("real floating", "complex floating")) + True + + """ + + if isinstance(dtype, type): + dtype = dpt.dtype(dtype) + + if isinstance(kind, type): + kind = dpt.dtype(kind) + elif isinstance(kind, tuple): + kind = tuple(dpt.dtype(k) if isinstance(k, type) else k for k in kind) + + return dpt.isdtype(dtype, kind) + + def issubdtype(arg1, arg2): """ Returns ``True`` if the first argument is a type code lower/equal diff --git a/dpnp/tests/test_dtype_routines.py b/dpnp/tests/test_dtype_routines.py new file mode 100644 index 000000000000..8fab16c70936 --- /dev/null +++ b/dpnp/tests/test_dtype_routines.py @@ -0,0 +1,65 @@ +import numpy +import pytest +from numpy.testing import assert_raises_regex + +import dpnp + +from .helper import numpy_version + +if numpy_version() >= "2.0.0": + from numpy._core.numerictypes import sctypes +else: + from numpy.core.numerictypes import sctypes + + +class TestIsDType: + dtype_group = { + "signed integer": sctypes["int"], + "unsigned integer": sctypes["uint"], + "integral": sctypes["int"] + sctypes["uint"], + "real floating": sctypes["float"], + "complex floating": sctypes["complex"], + "numeric": ( + sctypes["int"] + + sctypes["uint"] + + sctypes["float"] + + sctypes["complex"] + ), + } + + @pytest.mark.parametrize( + "dt, close_dt", + [ + # TODO: replace with (dpnp.uint64, dpnp.uint32) once available + (dpnp.int64, dpnp.int32), + (numpy.uint64, numpy.uint32), + (dpnp.float64, dpnp.float32), + (dpnp.complex128, dpnp.complex64), + ], + ) + @pytest.mark.parametrize("dt_group", [None] + list(dtype_group.keys())) + def test_basic(self, dt, close_dt, dt_group): + # First check if same dtypes return "True" and different ones + # give "False" (even if they're close in the dtype hierarchy). + if dt_group is None: + assert dpnp.isdtype(dt, dt) + assert not dpnp.isdtype(dt, close_dt) + assert dpnp.isdtype(dt, (dt, close_dt)) + + # Check that dtype and a dtype group that it belongs to return "True", + # and "False" otherwise. + elif dt in self.dtype_group[dt_group]: + assert dpnp.isdtype(dt, dt_group) + assert dpnp.isdtype(dt, (close_dt, dt_group)) + else: + assert not dpnp.isdtype(dt, dt_group) + + def test_invalid_args(self): + with assert_raises_regex(TypeError, r"Expected instance of.*"): + dpnp.isdtype("int64", dpnp.int64) + + with assert_raises_regex(TypeError, r"Unsupported data type kind:.*"): + dpnp.isdtype(dpnp.int64, 1) + + with assert_raises_regex(ValueError, r"Unrecognized data type kind:.*"): + dpnp.isdtype(dpnp.int64, "int64")