Skip to content

Commit b5efd40

Browse files
committed
ENH: test minimium() with python scalars
1 parent 0b89c52 commit b5efd40

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
198198
def is_scalar(x):
199199
return isinstance(x, (int, float, complex, bool))
200200

201+
201202
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
202203
dtype_value_pairs = []
203204
for name, value in mapping.items():

array_api_tests/hypothesis_helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,20 @@ def two_mutual_arrays(
571571
)
572572
return arrays1, arrays2
573573

574+
575+
@composite
576+
def array_and_py_scalar(draw, dtypes):
577+
"""Draw a pair: (array, scalar) or (scalar, array)."""
578+
dtype = draw(sampled_from(dtypes))
579+
scalar_var = draw(scalars(just(dtype), finite=True))
580+
array_var = draw(arrays(dtype, shape=shapes(min_dims=1)))
581+
582+
if draw(booleans()):
583+
return scalar_var, array_var
584+
else:
585+
return array_var, scalar_var
586+
587+
574588
@composite
575589
def kwargs(draw, **kw):
576590
"""

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,31 @@ def binary_param_assert_against_refimpl(
690690
)
691691

692692

693+
def _convert_scalars_helper(x1, x2):
694+
"""Convert python scalar to arrays, record the shapes/dtypes of arrays.
695+
696+
For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697+
and all arguments converted to arrays.
698+
699+
dtypes are separate to help distinguishing between
700+
`py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701+
"""
702+
if dh.is_scalar(x1):
703+
in_dtypes = [x2.dtype]
704+
in_shapes = [x2.shape]
705+
x1a, x2a = xp.asarray(x1), x2
706+
elif dh.is_scalar(x2):
707+
in_dtypes = [x1.dtype]
708+
in_shapes = [x1.shape]
709+
x1a, x2a = x1, xp.asarray(x2)
710+
else:
711+
in_dtypes = [x1.dtype, x2.dtype]
712+
in_shapes = [x1.shape, x2.shape]
713+
x1a, x2a = x1, x2
714+
715+
return in_dtypes, in_shapes, (x1a, x2a)
716+
717+
693718
@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
694719
@given(data=st.data())
695720
def test_abs(ctx, data):
@@ -1468,13 +1493,27 @@ def test_maximum(x1, x2):
14681493
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
14691494

14701495

1496+
def _assert_correctness_binary(name, in_dtypes, in_shapes, in_arrs, out):
1497+
x1a, x2a = in_arrs
1498+
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype)
1499+
ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
1500+
binary_assert_against_refimpl(name, x1a, x2a, out, min, strict_check=True)
1501+
1502+
14711503
@pytest.mark.min_version("2023.12")
14721504
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14731505
def test_minimum(x1, x2):
14741506
out = xp.minimum(x1, x2)
1475-
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1476-
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1477-
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
1507+
_assert_correctness_binary("minimum", [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out)
1508+
1509+
1510+
@pytest.mark.min_version("2024.12")
1511+
@given(hh.array_and_py_scalar(dh.real_float_dtypes))
1512+
def test_minimum_with_scalars(x1x2):
1513+
x1, x2 = x1x2
1514+
out = xp.minimum(x1, x2)
1515+
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
1516+
_assert_correctness_binary("minimum", in_dtypes, in_shapes, (x1a, x2a), out)
14781517

14791518

14801519
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))

0 commit comments

Comments
 (0)