diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 7e7294f8..1bab8d34 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -449,7 +449,7 @@ def two_broadcastable_shapes(draw): ) @composite -def scalars(draw, dtypes, finite=False): +def scalars(draw, dtypes, finite=False, **kwds): """ Strategy to generate a scalar that matches a dtype strategy @@ -463,12 +463,12 @@ def scalars(draw, dtypes, finite=False): return draw(booleans()) elif dtype == float64: if finite: - return draw(floats(allow_nan=False, allow_infinity=False)) - return draw(floats()) + return draw(floats(allow_nan=False, allow_infinity=False, **kwds)) + return draw(floats(), **kwds) elif dtype == float32: if finite: - return draw(floats(width=32, allow_nan=False, allow_infinity=False)) - return draw(floats(width=32)) + return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds)) + return draw(floats(width=32, **kwds)) elif dtype == complex64: if finite: return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False)) @@ -591,8 +591,16 @@ def two_mutual_arrays( def array_and_py_scalar(draw, dtypes): """Draw a pair: (array, scalar) or (scalar, array).""" dtype = draw(sampled_from(dtypes)) - scalar_var = draw(scalars(just(dtype), finite=True)) - array_var = draw(arrays(dtype, shape=shapes(min_dims=1))) + + scalar_var = draw(scalars(just(dtype), finite=True, + **{'min_value': 1/ (2<<5), 'max_value': 2<<5} + )) + + elements={} + if dtype in dh.real_float_dtypes: + elements = {'allow_nan': False, 'allow_infinity': False, + 'min_value': 1.0 / (2<<5), 'max_value': 2<<5} + array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements)) if draw(booleans()): return scalar_var, array_var diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 129f5c21..d79fd17b 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1816,19 +1816,11 @@ def _filter_zero(x): (xp.less_equal, "les_equal", operator.le, {}, xp.bool), (xp.greater, "greater", operator.gt, {}, xp.bool), (xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool), - (xp.remainder, "remainder", operator.mod, {}, None), - (xp.floor_divide, "floor_divide", operator.floordiv, {}, None), ], ids=lambda func_data: func_data[1] # use names for test IDs ) @given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes)) def test_binary_with_scalars_real(func_data, x1x2): - - if func_data[1] == "remainder": - assume(_filter_zero(x1x2[1])) - if func_data[1] == "floor_divide": - assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1])) - _check_binary_with_scalars(func_data, x1x2) @@ -1847,6 +1839,24 @@ def test_binary_with_scalars_bool(func_data, x1x2): _check_binary_with_scalars(func_data, x1x2) +@pytest.mark.min_version("2024.12") +@pytest.mark.parametrize('func_data', + # xp_func, name, refimpl, kwargs, expected_dtype + [ + + (xp.floor_divide, "floor_divide", operator.floordiv, {}, None), + (xp.remainder, "remainder", operator.mod, {}, None), + ], + ids=lambda func_data: func_data[1] # use names for test IDs +) +@given(x1x2=hh.array_and_py_scalar([xp.int64])) +def test_binary_with_scalars_int(func_data, x1x2): + + assume(_filter_zero(x1x2[1])) + assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1])) + _check_binary_with_scalars(func_data, x1x2) + + @pytest.mark.min_version("2024.12") @pytest.mark.parametrize('func_data', # xp_func, name, refimpl, kwargs, expected_dtype