Skip to content

Commit ca5ba5b

Browse files
Add TestMgrid
1 parent 9cbf764 commit ca5ba5b

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

dpnp/tests/test_arraycreation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,41 @@ def test_meshgrid_raise_error():
968968
dpnp.meshgrid(b, indexing="ab")
969969

970970

971+
class TestMgrid:
972+
def check_results(self, result, expected):
973+
if isinstance(result, (list, tuple)):
974+
assert len(result) == len(expected)
975+
for dp_arr, np_arr in zip(result, expected):
976+
assert_allclose(dp_arr, np_arr)
977+
else:
978+
assert_allclose(result, expected)
979+
980+
@pytest.mark.parametrize(
981+
"slice",
982+
[
983+
slice(0, 5, 0.5), # float step
984+
slice(0, 5, 5j), # complex step
985+
],
986+
)
987+
def test_single_slice(self, slice):
988+
dpnp_result = dpnp.mgrid[slice]
989+
numpy_result = numpy.mgrid[slice]
990+
self.check_results(dpnp_result, numpy_result)
991+
992+
@pytest.mark.parametrize(
993+
"slices",
994+
[
995+
(slice(None, 5, 1), slice(None, 10, 2)), # no start
996+
(slice(0, 5), slice(0, 10)), # no step
997+
(slice(0, 5.5, 1), slice(0, 10, 3j)), # float stop and complex step
998+
],
999+
)
1000+
def test_md_slice(self, slices):
1001+
dpnp_result = dpnp.mgrid[slices]
1002+
numpy_result = numpy.mgrid[slices]
1003+
self.check_results(dpnp_result, numpy_result)
1004+
1005+
9711006
def test_exception_tri():
9721007
x = dpnp.ones((2, 2))
9731008
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)