@@ -968,6 +968,41 @@ def test_meshgrid_raise_error():
968968 dpnp .meshgrid (b , indexing = "ab" )
969969
970970
971+ class TestMgrid :
972+ def check_results (self , result , expected ):
973+ if isinstance (result , (list , tuple )):
974+ assert len (result ) == len (expected )
975+ for dp_arr , np_arr in zip (result , expected ):
976+ assert_allclose (dp_arr , np_arr )
977+ else :
978+ assert_allclose (result , expected )
979+
980+ @pytest .mark .parametrize (
981+ "slice" ,
982+ [
983+ slice (0 , 5 , 0.5 ), # float step
984+ slice (0 , 5 , 5j ), # complex step
985+ ],
986+ )
987+ def test_single_slice (self , slice ):
988+ dpnp_result = dpnp .mgrid [slice ]
989+ numpy_result = numpy .mgrid [slice ]
990+ self .check_results (dpnp_result , numpy_result )
991+
992+ @pytest .mark .parametrize (
993+ "slices" ,
994+ [
995+ (slice (None , 5 , 1 ), slice (None , 10 , 2 )), # no start
996+ (slice (0 , 5 ), slice (0 , 10 )), # no step
997+ (slice (0 , 5.5 , 1 ), slice (0 , 10 , 3j )), # float stop and complex step
998+ ],
999+ )
1000+ def test_md_slice (self , slices ):
1001+ dpnp_result = dpnp .mgrid [slices ]
1002+ numpy_result = numpy .mgrid [slices ]
1003+ self .check_results (dpnp_result , numpy_result )
1004+
1005+
9711006def test_exception_tri ():
9721007 x = dpnp .ones ((2 , 2 ))
9731008 with pytest .raises (TypeError ):
0 commit comments