@@ -538,8 +538,12 @@ def test_linspace(num, dtype, endpoint, data):
538538 raise
539539
540540
541- @given (dtype = hh .numeric_dtypes , data = st .data ())
542- def test_meshgrid (dtype , data ):
541+ @given (
542+ dtype = hh .numeric_dtypes ,
543+ kw = hh .kwargs (indexing = st .sampled_from (["xy" , "ij" ])),
544+ data = st .data ()
545+ )
546+ def test_meshgrid (dtype , kw , data ):
543547 # The number and size of generated arrays is arbitrarily limited to prevent
544548 # meshgrid() running out of memory.
545549 shapes = data .draw (
@@ -557,11 +561,18 @@ def test_meshgrid(dtype, data):
557561 # sanity check
558562 # assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
559563
560- repro_snippet = ph .format_snippet (f"xp.meshgrid(*arrays) with { arrays = } " )
564+ tgt_shape = [a .shape [0 ] for a in arrays ]
565+ if len (tgt_shape ) > 1 and kw .get ('indexing' , 'xy' ) == 'xy' :
566+ tgt_shape [0 ], tgt_shape [1 ] = tgt_shape [1 ], tgt_shape [0 ]
567+ tgt_shape = tuple (tgt_shape )
568+
569+ repro_snippet = ph .format_snippet (f"xp.meshgrid(*arrays, **kw) with { arrays = } and { kw = } " )
561570 try :
562- out = xp .meshgrid (* arrays )
571+ out = xp .meshgrid (* arrays , ** kw )
563572 for i , x in enumerate (out ):
564573 ph .assert_dtype ("meshgrid" , in_dtype = dtype , out_dtype = x .dtype , repr_name = f"out[{ i } ].dtype" )
574+ ph .assert_shape ("meshgrid" , out_shape = x .shape , expected = tgt_shape )
575+
565576 except Exception as exc :
566577 ph .add_note (exc , repro_snippet )
567578 raise
0 commit comments