diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index c55b2da4..568dd7ce 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -538,8 +538,12 @@ def test_linspace(num, dtype, endpoint, data): raise -@given(dtype=hh.numeric_dtypes, data=st.data()) -def test_meshgrid(dtype, data): +@given( + dtype=hh.numeric_dtypes, + kw=hh.kwargs(indexing=st.sampled_from(["xy", "ij"])), + data=st.data() +) +def test_meshgrid(dtype, kw, data): # The number and size of generated arrays is arbitrarily limited to prevent # meshgrid() running out of memory. shapes = data.draw( @@ -557,11 +561,17 @@ def test_meshgrid(dtype, data): # sanity check # assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE - repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays) with {arrays = }") + tgt_shape = [a.shape[0] for a in arrays] + if len(tgt_shape) > 1 and kw.get('indexing', 'xy') == 'xy': + tgt_shape[0], tgt_shape[1] = tgt_shape[1], tgt_shape[0] + tgt_shape = tuple(tgt_shape) + + repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays, **kw) with {arrays = } and {kw = }") try: - out = xp.meshgrid(*arrays) + out = xp.meshgrid(*arrays, **kw) for i, x in enumerate(out): ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") + ph.assert_shape("meshgrid", out_shape=x.shape, expected=tgt_shape) except Exception as exc: ph.add_note(exc, repro_snippet) raise