diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index c55b2da4..50dd5ada 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -13,6 +13,7 @@ from . import shape_helpers as sh from . import xps from .typing import DataType, Scalar +from . import api_version class frange(NamedTuple): @@ -560,6 +561,8 @@ def test_meshgrid(dtype, data): repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays) with {arrays = }") try: out = xp.meshgrid(*arrays) + + assert type(out) == list if api_version < "2025.12" else tuple for i, x in enumerate(out): ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") except Exception as exc: diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index f6613b60..92116474 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -12,7 +12,7 @@ from . import shape_helpers as sh from . import xps from .typing import DataType - +from . import api_version # TODO: test with complex dtypes def non_complex_dtypes(): @@ -107,6 +107,8 @@ def test_broadcast_arrays(shapes, data): try: out = xp.broadcast_arrays(*arrays) + assert type(out) == list if api_version < "2025.12" else tuple + expected_shape = sh.broadcast_shapes(*shapes) for i, x in enumerate(arrays): ph.assert_dtype( diff --git a/array_api_tests/test_inspection_functions.py b/array_api_tests/test_inspection_functions.py index d210535e..ae9362b5 100644 --- a/array_api_tests/test_inspection_functions.py +++ b/array_api_tests/test_inspection_functions.py @@ -3,6 +3,7 @@ from array_api_tests.dtype_helpers import available_kinds, dtype_names from . import xp +from . import api_version pytestmark = pytest.mark.min_version("2023.12") @@ -31,7 +32,7 @@ def test_devices(self): assert hasattr(out, "devices") assert hasattr(out, "default_device") - assert isinstance(out.devices(), list) + assert isinstance(out.devices(), list if api_version < "2025.12" else tuple) if out.default_device() is not None: # Per https://github.com/data-apis/array-api/issues/923 # default_device() can return None. Otherwise, it must be a valid device.