diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index 1c52a983..c3bd341e 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -35,6 +35,8 @@ def __repr__(self): _funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout _top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"] +_top_level_attrs += ['broadcast_shapes'] # FIXME: until the spec is not updated + for attr in _top_level_attrs: try: globals()[attr] = getattr(xp, attr) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index f6613b60..6e9ab1b4 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -128,6 +128,31 @@ def test_broadcast_arrays(shapes, data): raise + +class TestBroadcastShapes: + + @given(shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes)) + def test_broadcast_shapes(self, shapes): + repro_snippet = ph.format_snippet(f"xp.broadcast_shapes(*shapes) with {shapes = }") + try: + out_shape = xp.broadcast_shapes(*shapes) + expected_shape = sh.broadcast_shapes(*shapes) + assert out_shape == expected_shape + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + + def test_empty(self): + assert xp.broadcast_shapes() == () + + @given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1, min_side=3)) + def test_error(self, shapes): + s1, s2 = shapes + s1 = s1[:-1] + (s1[-1] + 1,) + with pytest.raises(ValueError): + xp.broadcast_shapes(s1, s2) + + @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data()) def test_broadcast_to(x, data): shape = data.draw(