From 06b5e8d889b99af384f3fc16729ee6287aaf2923 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 21 Nov 2024 18:47:17 +0200 Subject: [PATCH] ENH: add result_type_false: check the minimum promotion rules according to the spec Conforming array libraries may extend the minimum promotion rules. If they do, they will fail this new test and may want to xfail it. --- array_api_tests/hypothesis_helpers.py | 21 ++++++++++++++++++++- array_api_tests/test_data_type_functions.py | 10 ++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index d9fdabd5..0b633ce7 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -10,7 +10,7 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, just, lists, none, one_of, - sampled_from, shared, builds) + sampled_from, shared, builds, permutations) from . import _array_module as xp, api_version from . import array_helpers as ah @@ -148,6 +148,25 @@ def mutually_promotable_dtypes( return one_of(strats).map(tuple) +@composite +def mutually_non_promotable_dtypes( + draw, + max_size: Optional[int] = 2, +) -> Sequence[Tuple[DataType, ...]]: + """Generate a pair of dtypes which cannot be promoted.""" + assert max_size == 2 + + _categories = [ + (xp.bool,), + dh.uint_dtypes + dh.int_dtypes, + dh.real_float_dtypes + dh.complex_dtypes + ] + cat_st = permutations(_categories).map(lambda s: s[:2]) + cat_from, cat_to = draw(cat_st) + from_, to = draw(sampled_from(cat_from)), draw(sampled_from(cat_to)) + return from_, to + + class OnewayPromotableDtypes(NamedTuple): input_dtype: DataType result_dtype: DataType diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 34c40024..2ae98329 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -204,3 +204,13 @@ def test_isdtype(dtype, kind): def test_result_type(dtypes): out = xp.result_type(*dtypes) ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + + +@given(hh.mutually_non_promotable_dtypes(2)) +def test_result_type_false(dtypes): + """Test _very_ strict promotion rules according to the spec. + Conforming array libraries may extend the promotion rules, and + then they'll need to xfail this test. + """ + with pytest.raises(TypeError): + xp.result_type(*dtypes)