Skip to content

Commit 9b3623f

Browse files
committed
ENH: test atan2, hypot, logaddexp with scalars
1 parent 7b2ed72 commit 9b3623f

File tree

1 file changed

+72
-10
lines changed

1 file changed

+72
-10
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -821,10 +821,30 @@ def test_atan(x):
821821
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
822822
def test_atan2(x1, x2):
823823
out = xp.atan2(x1, x2)
824-
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
825-
ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
826-
refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2
827-
binary_assert_against_refimpl("atan2", x1, x2, out, refimpl)
824+
_assert_correctness_binary(
825+
"atan",
826+
cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2,
827+
in_dtypes=[x1.dtype, x2.dtype],
828+
in_shapes=[x1.shape, x2.shape],
829+
in_arrs=[x1, x2],
830+
out=out,
831+
)
832+
833+
834+
@pytest.mark.min_version("2024.12")
835+
@given(hh.array_and_py_scalar(dh.real_float_dtypes))
836+
def test_atan2_with_scalars(x1x2):
837+
x1, x2 = x1x2
838+
out = xp.atan2(x1, x2)
839+
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
840+
_assert_correctness_binary(
841+
"atan2",
842+
cmath.atan2 if x1a.dtype in dh.complex_dtypes else math.atan2,
843+
in_dtypes=in_dtypes,
844+
in_shapes=in_shapes,
845+
in_arrs=[x1a, x2a],
846+
out=out,
847+
)
828848

829849

830850
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
@@ -1290,11 +1310,31 @@ def test_greater_equal(ctx, data):
12901310
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
12911311
def test_hypot(x1, x2):
12921312
out = xp.hypot(x1, x2)
1293-
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1294-
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1295-
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
1313+
_assert_correctness_binary(
1314+
"hypot",
1315+
math.hypot,
1316+
in_dtypes=[x1.dtype, x2.dtype],
1317+
in_shapes=[x1.shape, x2.shape],
1318+
in_arrs=[x1, x2],
1319+
out=out
1320+
)
12961321

12971322

1323+
@pytest.mark.min_version("2024.12")
1324+
@given(hh.array_and_py_scalar(dh.real_float_dtypes))
1325+
def test_hypot_with_scalars(x1x2):
1326+
x1, x2 = x1x2
1327+
out = xp.hypot(x1, x2)
1328+
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
1329+
_assert_correctness_binary(
1330+
"hypot",
1331+
math.hypot,
1332+
in_dtypes=in_dtypes,
1333+
in_shapes=in_shapes,
1334+
in_arrs=(x1a, x2a),
1335+
out=out
1336+
)
1337+
12981338

12991339
@pytest.mark.min_version("2022.12")
13001340
@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
@@ -1443,12 +1483,34 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14431483
raise OverflowError
14441484

14451485

1486+
@pytest.mark.min_version("2023.12")
14461487
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14471488
def test_logaddexp(x1, x2):
14481489
out = xp.logaddexp(x1, x2)
1449-
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1450-
ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1451-
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl)
1490+
_assert_correctness_binary(
1491+
"logaddexp",
1492+
logaddexp_refimpl,
1493+
in_dtypes=[x1.dtype, x2.dtype],
1494+
in_shapes=[x1.shape, x2.shape],
1495+
in_arrs=[x1, x2],
1496+
out=out
1497+
)
1498+
1499+
1500+
@pytest.mark.min_version("2024.12")
1501+
@given(hh.array_and_py_scalar(dh.real_float_dtypes))
1502+
def test_logaddexp_with_scalars(x1x2):
1503+
x1, x2 = x1x2
1504+
out = xp.logaddexp(x1, x2)
1505+
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
1506+
_assert_correctness_binary(
1507+
"logaddexp",
1508+
logaddexp_refimpl,
1509+
in_dtypes=in_dtypes,
1510+
in_shapes=in_shapes,
1511+
in_arrs=(x1a, x2a),
1512+
out=out
1513+
)
14521514

14531515

14541516
@given(hh.arrays(dtype=xp.bool, shape=hh.shapes()))

0 commit comments

Comments
 (0)