@@ -821,10 +821,30 @@ def test_atan(x):
821
821
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
822
822
def test_atan2 (x1 , x2 ):
823
823
out = xp .atan2 (x1 , x2 )
824
- ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
825
- ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
826
- refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
827
- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
824
+ _assert_correctness_binary (
825
+ "atan" ,
826
+ cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2 ,
827
+ in_dtypes = [x1 .dtype , x2 .dtype ],
828
+ in_shapes = [x1 .shape , x2 .shape ],
829
+ in_arrs = [x1 , x2 ],
830
+ out = out ,
831
+ )
832
+
833
+
834
+ @pytest .mark .min_version ("2024.12" )
835
+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
836
+ def test_atan2_with_scalars (x1x2 ):
837
+ x1 , x2 = x1x2
838
+ out = xp .atan2 (x1 , x2 )
839
+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
840
+ _assert_correctness_binary (
841
+ "atan2" ,
842
+ cmath .atan2 if x1a .dtype in dh .complex_dtypes else math .atan2 ,
843
+ in_dtypes = in_dtypes ,
844
+ in_shapes = in_shapes ,
845
+ in_arrs = [x1a , x2a ],
846
+ out = out ,
847
+ )
828
848
829
849
830
850
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1290,11 +1310,31 @@ def test_greater_equal(ctx, data):
1290
1310
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1291
1311
def test_hypot (x1 , x2 ):
1292
1312
out = xp .hypot (x1 , x2 )
1293
- ph .assert_dtype ("hypot" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1294
- ph .assert_result_shape ("hypot" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1295
- binary_assert_against_refimpl ("hypot" , x1 , x2 , out , math .hypot )
1313
+ _assert_correctness_binary (
1314
+ "hypot" ,
1315
+ math .hypot ,
1316
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1317
+ in_shapes = [x1 .shape , x2 .shape ],
1318
+ in_arrs = [x1 , x2 ],
1319
+ out = out
1320
+ )
1296
1321
1297
1322
1323
+ @pytest .mark .min_version ("2024.12" )
1324
+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
1325
+ def test_hypot_with_scalars (x1x2 ):
1326
+ x1 , x2 = x1x2
1327
+ out = xp .hypot (x1 , x2 )
1328
+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1329
+ _assert_correctness_binary (
1330
+ "hypot" ,
1331
+ math .hypot ,
1332
+ in_dtypes = in_dtypes ,
1333
+ in_shapes = in_shapes ,
1334
+ in_arrs = (x1a , x2a ),
1335
+ out = out
1336
+ )
1337
+
1298
1338
1299
1339
@pytest .mark .min_version ("2022.12" )
1300
1340
@pytest .mark .skipif (hh .complex_dtypes .is_empty , reason = "no complex data types to draw from" )
@@ -1443,12 +1483,34 @@ def logaddexp_refimpl(l: float, r: float) -> float:
1443
1483
raise OverflowError
1444
1484
1445
1485
1486
+ @pytest .mark .min_version ("2023.12" )
1446
1487
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1447
1488
def test_logaddexp (x1 , x2 ):
1448
1489
out = xp .logaddexp (x1 , x2 )
1449
- ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1450
- ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1451
- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1490
+ _assert_correctness_binary (
1491
+ "logaddexp" ,
1492
+ logaddexp_refimpl ,
1493
+ in_dtypes = [x1 .dtype , x2 .dtype ],
1494
+ in_shapes = [x1 .shape , x2 .shape ],
1495
+ in_arrs = [x1 , x2 ],
1496
+ out = out
1497
+ )
1498
+
1499
+
1500
+ @pytest .mark .min_version ("2024.12" )
1501
+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
1502
+ def test_logaddexp_with_scalars (x1x2 ):
1503
+ x1 , x2 = x1x2
1504
+ out = xp .logaddexp (x1 , x2 )
1505
+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1506
+ _assert_correctness_binary (
1507
+ "logaddexp" ,
1508
+ logaddexp_refimpl ,
1509
+ in_dtypes = in_dtypes ,
1510
+ in_shapes = in_shapes ,
1511
+ in_arrs = (x1a , x2a ),
1512
+ out = out
1513
+ )
1452
1514
1453
1515
1454
1516
@given (hh .arrays (dtype = xp .bool , shape = hh .shapes ()))
0 commit comments