Skip to content

Commit 49bb99a

Browse files
author
Vahid Tavanashad
committed
update tests
1 parent e24fa99 commit 49bb99a

15 files changed

+222
-282
lines changed

dpnp/tests/test_amin_amax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_amax_amin(func, keepdims, dtype):
2222
for axis in range(len(a)):
2323
result = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
2424
expected = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
25-
assert_allclose(expected, result)
25+
assert_allclose(result, expected)
2626

2727

2828
def _get_min_max_input(type, shape):

dpnp/tests/test_arraycreation.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_arange(start, stop, step, dtype):
216216
func = lambda xp: xp.arange(start, stop=stop, step=step, dtype=dtype)
217217

218218
exp_array = func(numpy)
219-
res_array = func(dpnp).asnumpy()
219+
res_array = func(dpnp)
220220

221221
if dtype is None:
222222
_device = dpctl.SyclQueue().sycl_device
@@ -234,7 +234,7 @@ def test_arange(start, stop, step, dtype):
234234
_dtype, dpnp.complexfloating
235235
):
236236
assert_allclose(
237-
exp_array, res_array, rtol=rtol_mult * numpy.finfo(_dtype).eps
237+
res_array, exp_array, rtol=rtol_mult * numpy.finfo(_dtype).eps
238238
)
239239
else:
240240
assert_array_equal(exp_array, res_array)
@@ -540,7 +540,7 @@ def test_vander(array, dtype, n, increase):
540540
a_np = numpy.array(array, dtype=dtype)
541541
a_dpnp = dpnp.array(array, dtype=dtype)
542542

543-
assert_allclose(vander_func(numpy, a_np), vander_func(dpnp, a_dpnp))
543+
assert_allclose(vander_func(dpnp, a_dpnp), vander_func(numpy, a_np))
544544

545545

546546
def test_vander_raise_error():
@@ -560,7 +560,7 @@ def test_vander_raise_error():
560560
)
561561
def test_vander_seq(sequence):
562562
vander_func = lambda xp, x: xp.vander(x)
563-
assert_allclose(vander_func(numpy, sequence), vander_func(dpnp, sequence))
563+
assert_allclose(vander_func(dpnp, sequence), vander_func(numpy, sequence))
564564

565565

566566
@pytest.mark.usefixtures("suppress_complex_warning")
@@ -607,19 +607,19 @@ def test_full_order(order1, order2):
607607

608608
assert ia.flags.c_contiguous == a.flags.c_contiguous
609609
assert ia.flags.f_contiguous == a.flags.f_contiguous
610-
assert numpy.array_equal(dpnp.asnumpy(ia), a)
610+
assert_equal(ia, a)
611611

612612

613613
def test_full_strides():
614614
a = numpy.full((3, 3), numpy.arange(3, dtype="i4"))
615615
ia = dpnp.full((3, 3), dpnp.arange(3, dtype="i4"))
616616
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
617-
assert_array_equal(dpnp.asnumpy(ia), a)
617+
assert_array_equal(ia, a)
618618

619619
a = numpy.full((3, 3), numpy.arange(6, dtype="i4")[::2])
620620
ia = dpnp.full((3, 3), dpnp.arange(6, dtype="i4")[::2])
621621
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
622-
assert_array_equal(dpnp.asnumpy(ia), a)
622+
assert_array_equal(ia, a)
623623

624624

625625
@pytest.mark.parametrize(
@@ -762,20 +762,12 @@ def test_linspace(start, stop, num, dtype, retstep):
762762
assert_dtype_allclose(res_dp, res_np)
763763

764764

765+
@pytest.mark.parametrize("func", ["geomspace", "linspace", "logspace"])
765766
@pytest.mark.parametrize(
766-
"func",
767-
["geomspace", "linspace", "logspace"],
768-
ids=["geomspace", "linspace", "logspace"],
767+
"start_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
769768
)
770769
@pytest.mark.parametrize(
771-
"start_dtype",
772-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
773-
ids=["float64", "float32", "int64", "int32"],
774-
)
775-
@pytest.mark.parametrize(
776-
"stop_dtype",
777-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
778-
ids=["float64", "float32", "int64", "int32"],
770+
"stop_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
779771
)
780772
def test_space_numpy_dtype(func, start_dtype, stop_dtype):
781773
start = numpy.array([1, 2, 3], dtype=start_dtype)
@@ -890,10 +882,7 @@ def test_geomspace(sign, dtype, num, endpoint):
890882
np_res = func(numpy)
891883
dpnp_res = func(dpnp)
892884

893-
if dtype in [numpy.int64, numpy.int32]:
894-
assert_allclose(dpnp_res, np_res, rtol=1)
895-
else:
896-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
885+
assert_allclose(dpnp_res, np_res)
897886

898887

899888
@pytest.mark.parametrize("start", [1j, 1 + 1j])
@@ -902,22 +891,22 @@ def test_geomspace_complex(start, stop):
902891
func = lambda xp: xp.geomspace(start, stop, num=10)
903892
np_res = func(numpy)
904893
dpnp_res = func(dpnp)
905-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
894+
assert_allclose(dpnp_res, np_res)
906895

907896

908897
@pytest.mark.parametrize("axis", [0, 1])
909898
def test_geomspace_axis(axis):
910899
func = lambda xp: xp.geomspace([2, 3], [20, 15], num=10, axis=axis)
911900
np_res = func(numpy)
912901
dpnp_res = func(dpnp)
913-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
902+
assert_allclose(dpnp_res, np_res)
914903

915904

916905
def test_geomspace_num0():
917906
func = lambda xp: xp.geomspace(1, 10, num=0, endpoint=False)
918907
np_res = func(numpy)
919908
dpnp_res = func(dpnp)
920-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
909+
assert_allclose(dpnp_res, np_res)
921910

922911

923912
@pytest.mark.parametrize("dtype", get_all_dtypes())
@@ -935,10 +924,7 @@ def test_logspace(dtype, num, endpoint):
935924
np_res = func(numpy)
936925
dpnp_res = func(dpnp)
937926

938-
if dtype in [numpy.int64, numpy.int32]:
939-
assert_allclose(dpnp_res, np_res, rtol=1)
940-
else:
941-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
927+
assert_allclose(dpnp_res, np_res)
942928

943929

944930
@pytest.mark.parametrize("axis", [0, 1])

0 commit comments

Comments
 (0)