Skip to content

Commit d99f44b

Browse files
author
Vahid Tavanashad
committed
update tests for search-special-statistics-umath
1 parent e24fa99 commit d99f44b

15 files changed

+222
-268
lines changed

dpnp/tests/test_amin_amax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_amax_amin(func, keepdims, dtype):
2222
for axis in range(len(a)):
2323
result = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
2424
expected = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
25-
assert_allclose(expected, result)
25+
assert_allclose(result, expected)
2626

2727

2828
def _get_min_max_input(type, shape):

dpnp/tests/test_arraycreation.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_arange(start, stop, step, dtype):
216216
func = lambda xp: xp.arange(start, stop=stop, step=step, dtype=dtype)
217217

218218
exp_array = func(numpy)
219-
res_array = func(dpnp).asnumpy()
219+
res_array = func(dpnp)
220220

221221
if dtype is None:
222222
_device = dpctl.SyclQueue().sycl_device
@@ -234,7 +234,7 @@ def test_arange(start, stop, step, dtype):
234234
_dtype, dpnp.complexfloating
235235
):
236236
assert_allclose(
237-
exp_array, res_array, rtol=rtol_mult * numpy.finfo(_dtype).eps
237+
res_array, exp_array, rtol=rtol_mult * numpy.finfo(_dtype).eps
238238
)
239239
else:
240240
assert_array_equal(exp_array, res_array)
@@ -540,7 +540,7 @@ def test_vander(array, dtype, n, increase):
540540
a_np = numpy.array(array, dtype=dtype)
541541
a_dpnp = dpnp.array(array, dtype=dtype)
542542

543-
assert_allclose(vander_func(numpy, a_np), vander_func(dpnp, a_dpnp))
543+
assert_allclose(vander_func(dpnp, a_dpnp), vander_func(numpy, a_np))
544544

545545

546546
def test_vander_raise_error():
@@ -560,7 +560,7 @@ def test_vander_raise_error():
560560
)
561561
def test_vander_seq(sequence):
562562
vander_func = lambda xp, x: xp.vander(x)
563-
assert_allclose(vander_func(numpy, sequence), vander_func(dpnp, sequence))
563+
assert_allclose(vander_func(dpnp, sequence), vander_func(numpy, sequence))
564564

565565

566566
@pytest.mark.usefixtures("suppress_complex_warning")
@@ -607,19 +607,19 @@ def test_full_order(order1, order2):
607607

608608
assert ia.flags.c_contiguous == a.flags.c_contiguous
609609
assert ia.flags.f_contiguous == a.flags.f_contiguous
610-
assert numpy.array_equal(dpnp.asnumpy(ia), a)
610+
assert_equal(ia, a)
611611

612612

613613
def test_full_strides():
614614
a = numpy.full((3, 3), numpy.arange(3, dtype="i4"))
615615
ia = dpnp.full((3, 3), dpnp.arange(3, dtype="i4"))
616616
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
617-
assert_array_equal(dpnp.asnumpy(ia), a)
617+
assert_array_equal(ia, a)
618618

619619
a = numpy.full((3, 3), numpy.arange(6, dtype="i4")[::2])
620620
ia = dpnp.full((3, 3), dpnp.arange(6, dtype="i4")[::2])
621621
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
622-
assert_array_equal(dpnp.asnumpy(ia), a)
622+
assert_array_equal(ia, a)
623623

624624

625625
@pytest.mark.parametrize(
@@ -891,9 +891,9 @@ def test_geomspace(sign, dtype, num, endpoint):
891891
dpnp_res = func(dpnp)
892892

893893
if dtype in [numpy.int64, numpy.int32]:
894-
assert_allclose(dpnp_res, np_res, rtol=1)
894+
assert_allclose(dpnp_res, np_res)
895895
else:
896-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
896+
assert_allclose(dpnp_res, np_res)
897897

898898

899899
@pytest.mark.parametrize("start", [1j, 1 + 1j])
@@ -902,22 +902,22 @@ def test_geomspace_complex(start, stop):
902902
func = lambda xp: xp.geomspace(start, stop, num=10)
903903
np_res = func(numpy)
904904
dpnp_res = func(dpnp)
905-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
905+
assert_allclose(dpnp_res, np_res)
906906

907907

908908
@pytest.mark.parametrize("axis", [0, 1])
909909
def test_geomspace_axis(axis):
910910
func = lambda xp: xp.geomspace([2, 3], [20, 15], num=10, axis=axis)
911911
np_res = func(numpy)
912912
dpnp_res = func(dpnp)
913-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
913+
assert_allclose(dpnp_res, np_res)
914914

915915

916916
def test_geomspace_num0():
917917
func = lambda xp: xp.geomspace(1, 10, num=0, endpoint=False)
918918
np_res = func(numpy)
919919
dpnp_res = func(dpnp)
920-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
920+
assert_allclose(dpnp_res, np_res)
921921

922922

923923
@pytest.mark.parametrize("dtype", get_all_dtypes())
@@ -936,9 +936,9 @@ def test_logspace(dtype, num, endpoint):
936936
dpnp_res = func(dpnp)
937937

938938
if dtype in [numpy.int64, numpy.int32]:
939-
assert_allclose(dpnp_res, np_res, rtol=1)
939+
assert_allclose(dpnp_res, np_res)
940940
else:
941-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
941+
assert_allclose(dpnp_res, np_res)
942942

943943

944944
@pytest.mark.parametrize("axis", [0, 1])

0 commit comments

Comments
 (0)