@@ -690,6 +690,31 @@ def binary_param_assert_against_refimpl(
690
690
)
691
691
692
692
693
+ def _convert_scalars_helper (x1 , x2 ):
694
+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
695
+
696
+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697
+ and all arguments converted to arrays.
698
+
699
+ dtypes are separate to help distinguishing between
700
+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701
+ """
702
+ if dh .is_scalar (x1 ):
703
+ in_dtypes = [x2 .dtype ]
704
+ in_shapes = [x2 .shape ]
705
+ x1a , x2a = xp .asarray (x1 ), x2
706
+ elif dh .is_scalar (x2 ):
707
+ in_dtypes = [x1 .dtype ]
708
+ in_shapes = [x1 .shape ]
709
+ x1a , x2a = x1 , xp .asarray (x2 )
710
+ else :
711
+ in_dtypes = [x1 .dtype , x2 .dtype ]
712
+ in_shapes = [x1 .shape , x2 .shape ]
713
+ x1a , x2a = x1 , x2
714
+
715
+ return in_dtypes , in_shapes , (x1a , x2a )
716
+
717
+
693
718
@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
694
719
@given (data = st .data ())
695
720
def test_abs (ctx , data ):
@@ -1468,13 +1493,27 @@ def test_maximum(x1, x2):
1468
1493
binary_assert_against_refimpl ("maximum" , x1 , x2 , out , max , strict_check = True )
1469
1494
1470
1495
1496
+ def _assert_correctness_binary (name , in_dtypes , in_shapes , in_arrs , out ):
1497
+ x1a , x2a = in_arrs
1498
+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype )
1499
+ ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
1500
+ binary_assert_against_refimpl (name , x1a , x2a , out , min , strict_check = True )
1501
+
1502
+
1471
1503
@pytest .mark .min_version ("2023.12" )
1472
1504
@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
1473
1505
def test_minimum (x1 , x2 ):
1474
1506
out = xp .minimum (x1 , x2 )
1475
- ph .assert_dtype ("minimum" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1476
- ph .assert_result_shape ("minimum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1477
- binary_assert_against_refimpl ("minimum" , x1 , x2 , out , min , strict_check = True )
1507
+ _assert_correctness_binary ("minimum" , [x1 .dtype , x2 .dtype ], [x1 .shape , x2 .shape ], (x1 , x2 ), out )
1508
+
1509
+
1510
+ @pytest .mark .min_version ("2024.12" )
1511
+ @given (hh .array_and_py_scalar (dh .real_float_dtypes ))
1512
+ def test_minimum_with_scalars (x1x2 ):
1513
+ x1 , x2 = x1x2
1514
+ out = xp .minimum (x1 , x2 )
1515
+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1516
+ _assert_correctness_binary ("minimum" , in_dtypes , in_shapes , (x1a , x2a ), out )
1478
1517
1479
1518
1480
1519
@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
0 commit comments