Skip to content

Commit bc1deed

Browse files
committed
ENH: expand testing of meshgrid: draw indexing, check shapes
1 parent 602b412 commit bc1deed

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,12 @@ def test_linspace(num, dtype, endpoint, data):
538538
raise
539539

540540

541-
@given(dtype=hh.numeric_dtypes, data=st.data())
542-
def test_meshgrid(dtype, data):
541+
@given(
542+
dtype=hh.numeric_dtypes,
543+
kw=hh.kwargs(indexing=st.sampled_from(["xy", "ij"])),
544+
data=st.data()
545+
)
546+
def test_meshgrid(dtype, kw, data):
543547
# The number and size of generated arrays is arbitrarily limited to prevent
544548
# meshgrid() running out of memory.
545549
shapes = data.draw(
@@ -557,11 +561,18 @@ def test_meshgrid(dtype, data):
557561
# sanity check
558562
# assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
559563

560-
repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays) with {arrays = }")
564+
tgt_shape = [a.shape[0] for a in arrays]
565+
if len(tgt_shape) > 1 and kw.get('indexing', 'xy') == 'xy':
566+
tgt_shape[0], tgt_shape[1] = tgt_shape[1], tgt_shape[0]
567+
tgt_shape = tuple(tgt_shape)
568+
569+
repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays, **kw) with {arrays = } and {kw = }")
561570
try:
562-
out = xp.meshgrid(*arrays)
571+
out = xp.meshgrid(*arrays, **kw)
563572
for i, x in enumerate(out):
564573
ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype")
574+
ph.assert_shape("meshgrid", out_shape=x.shape, expected=tgt_shape)
575+
565576
except Exception as exc:
566577
ph.add_note(exc, repro_snippet)
567578
raise

0 commit comments

Comments
 (0)