Skip to content

Commit 320202d

Browse files
vladislav.perevezentsevvlad-perevezentsev
authored andcommitted
Add TestPythonScalarConversion
1 parent 61505e4 commit 320202d

File tree

1 file changed

+44
-26
lines changed

1 file changed

+44
-26
lines changed

dpnp/tests/test_ndarray.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
assert_allclose,
66
assert_array_equal,
77
assert_equal,
8+
assert_raises,
89
assert_raises_regex,
910
)
1011

@@ -17,6 +18,7 @@
1718
get_complex_dtypes,
1819
get_float_dtypes,
1920
has_support_aspect64,
21+
numpy_version,
2022
)
2123
from .third_party.cupy import testing
2224

@@ -530,34 +532,50 @@ def test_print_dpnp_zero_shape():
530532
assert result == expected
531533

532534

533-
# Numpy will raise an error when converting a.ndim > 0 to a scalar
534-
# TODO: Discuss dpnp behavior according to these future changes
535-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
536-
@pytest.mark.parametrize("func", [bool, float, int, complex])
537-
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
538-
@pytest.mark.parametrize(
539-
"dtype", get_all_dtypes(no_float16=False, no_complex=True)
540-
)
541-
def test_scalar_type_casting(func, shape, dtype):
542-
a = numpy.full(shape, 5, dtype=dtype)
543-
ia = dpnp.full(shape, 5, dtype=dtype)
544-
assert func(a) == func(ia)
535+
class TestPythonScalarConversion:
536+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
537+
@pytest.mark.parametrize(
538+
"dtype", get_all_dtypes(no_float16=False, no_complex=True)
539+
)
540+
def test_bool_conversion(shape, dtype):
541+
a = numpy.full(shape, 5, dtype=dtype)
542+
ia = dpnp.full(shape, 5, dtype=dtype)
543+
assert bool(a) == bool(ia)
545544

545+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
546+
@pytest.mark.parametrize(
547+
"dtype", get_all_dtypes(no_float16=False, no_complex=True)
548+
)
549+
def test_bool_method_conversion(shape, dtype):
550+
a = numpy.full(shape, 5, dtype=dtype)
551+
ia = dpnp.full(shape, 5, dtype=dtype)
552+
assert a.__bool__() == ia.__bool__()
546553

547-
# Numpy will raise an error when converting a.ndim > 0 to a scalar
548-
# TODO: Discuss dpnp behavior according to these future changes
549-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
550-
@pytest.mark.parametrize(
551-
"method", ["__bool__", "__float__", "__int__", "__complex__"]
552-
)
553-
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
554-
@pytest.mark.parametrize(
555-
"dtype", get_all_dtypes(no_float16=False, no_complex=True)
556-
)
557-
def test_scalar_type_casting_by_method(method, shape, dtype):
558-
a = numpy.full(shape, 4.7, dtype=dtype)
559-
ia = dpnp.full(shape, 4.7, dtype=dtype)
560-
assert_allclose(getattr(a, method)(), getattr(ia, method)(), rtol=1e-06)
554+
@pytest.mark.parametrize("func", [float, int, complex])
555+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
556+
@pytest.mark.parametrize(
557+
"dtype", get_all_dtypes(no_float16=False, no_complex=True)
558+
)
559+
def test_non_bool_conversion(func, shape, dtype):
560+
a = numpy.full(shape, 5, dtype=dtype)
561+
ia = dpnp.full(shape, 5, dtype=dtype)
562+
assert_raises(TypeError, func(ia))
563+
564+
if numpy_version() >= "2.4.0":
565+
assert_raises(TypeError, func(a))
566+
567+
@pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"])
568+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
569+
@pytest.mark.parametrize(
570+
"dtype", get_all_dtypes(no_float16=False, no_complex=True)
571+
)
572+
def test_non_bool_method_conversion(method, shape, dtype):
573+
a = numpy.full(shape, 5, dtype=dtype)
574+
ia = dpnp.full(shape, 5, dtype=dtype)
575+
assert_raises(TypeError, getattr(ia, method)())
576+
577+
if numpy_version() >= "2.4.0":
578+
assert_raises(TypeError, getattr(a, method)())
561579

562580

563581
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])

0 commit comments

Comments
 (0)