Skip to content

Commit 81ce954

Browse files
author
Vahid Tavanashad
committed
address comment
1 parent dda7d2c commit 81ce954

File tree

2 files changed

+48
-74
lines changed

2 files changed

+48
-74
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,19 +1044,18 @@ def test_concat_stack(func, data1, data2, device):
10441044
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)
10451045

10461046

1047+
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10471048
class TestDelete:
10481049
@pytest.mark.parametrize(
10491050
"obj",
10501051
[slice(None, None, 2), 3, [2, 3]],
10511052
ids=["slice", "scalar", "list"],
10521053
)
1053-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10541054
def test_delete(self, obj, device):
10551055
x = dpnp.arange(5, device=device)
10561056
result = dpnp.delete(x, obj)
10571057
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
10581058

1059-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10601059
def test_obj_ndarray(self, device):
10611060
x = dpnp.arange(5, device=device)
10621061
y = dpnp.array([1, 4], device=device)
@@ -1066,13 +1065,13 @@ def test_obj_ndarray(self, device):
10661065
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
10671066

10681067

1068+
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10691069
class TestInsert:
10701070
@pytest.mark.parametrize(
10711071
"obj",
10721072
[slice(None, None, 2), 3, [2, 3]],
10731073
ids=["slice", "scalar", "list"],
10741074
)
1075-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10761075
def test_basic(self, obj, device):
10771076
x = dpnp.arange(5, device=device)
10781077
result = dpnp.insert(x, obj, 3)
@@ -1083,7 +1082,6 @@ def test_basic(self, obj, device):
10831082
[slice(None, None, 3), 3, [2, 3]],
10841083
ids=["slice", "scalar", "list"],
10851084
)
1086-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10871085
def test_values_ndarray(self, obj, device):
10881086
x = dpnp.arange(5, device=device)
10891087
y = dpnp.array([1, 4], device=device)
@@ -1093,7 +1091,6 @@ def test_values_ndarray(self, obj, device):
10931091
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
10941092

10951093
@pytest.mark.parametrize("values", [-2, [-1, -2]], ids=["scalar", "list"])
1096-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
10971094
def test_obj_ndarray(self, values, device):
10981095
x = dpnp.arange(5, device=device)
10991096
y = dpnp.array([1, 4], device=device)
@@ -1102,7 +1099,6 @@ def test_obj_ndarray(self, values, device):
11021099
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
11031100
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
11041101

1105-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
11061102
def test_obj_values_ndarray(self, device):
11071103
x = dpnp.arange(5, device=device)
11081104
y = dpnp.array([1, 4], device=device)

dpnp/tests/test_usm_type.py

Lines changed: 46 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -788,30 +788,29 @@ def test_split(func, data1, usm_type):
788788
assert y[1].usm_type == usm_type
789789

790790

791+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
791792
class TestDelete:
792793
@pytest.mark.parametrize(
793794
"obj",
794795
[slice(None, None, 2), 3, [2, 3]],
795796
ids=["slice", "scalar", "list"],
796797
)
797-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
798798
def test_delete(self, obj, usm_type):
799799
x = dpnp.arange(5, usm_type=usm_type)
800800
result = dpnp.delete(x, obj)
801801

802802
assert x.usm_type == usm_type
803803
assert result.usm_type == usm_type
804804

805-
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
806-
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
807-
def test_obj_ndarray(self, usm_type_x, usm_type_y):
808-
x = dpnp.arange(5, usm_type=usm_type_x)
809-
y = dpnp.array([1, 4], usm_type=usm_type_y)
805+
@pytest.mark.parametrize("usm_type_other", list_of_usm_types)
806+
def test_obj_ndarray(self, usm_type, usm_type_other):
807+
x = dpnp.arange(5, usm_type=usm_type)
808+
y = dpnp.array([1, 4], usm_type=usm_type_other)
810809
z = dpnp.delete(x, y)
811810

812-
assert x.usm_type == usm_type_x
813-
assert y.usm_type == usm_type_y
814-
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
811+
assert x.usm_type == usm_type
812+
assert y.usm_type == usm_type_other
813+
assert z.usm_type == du.get_coerced_usm_type([usm_type, usm_type_other])
815814

816815

817816
@pytest.mark.parametrize("usm_type", list_of_usm_types)
@@ -829,8 +828,8 @@ def test_einsum(usm_type):
829828
assert result.usm_type == usm_type
830829

831830

831+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
832832
class TestInsert:
833-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
834833
@pytest.mark.parametrize(
835834
"obj",
836835
[slice(None, None, 2), 3, [2, 3]],
@@ -848,43 +847,40 @@ def test_bacis(self, usm_type, obj):
848847
[slice(None, None, 3), 3, [2, 3]],
849848
ids=["slice", "scalar", "list"],
850849
)
851-
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
852-
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
853-
def test_values_ndarray(self, obj, usm_type_x, usm_type_y):
854-
x = dpnp.arange(5, usm_type=usm_type_x)
855-
y = dpnp.array([1, 4], usm_type=usm_type_y)
850+
@pytest.mark.parametrize("usm_type_other", list_of_usm_types)
851+
def test_values_ndarray(self, obj, usm_type, usm_type_other):
852+
x = dpnp.arange(5, usm_type=usm_type)
853+
y = dpnp.array([1, 4], usm_type=usm_type_other)
856854
z = dpnp.insert(x, obj, y)
857855

858-
assert x.usm_type == usm_type_x
859-
assert y.usm_type == usm_type_y
860-
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
856+
assert x.usm_type == usm_type
857+
assert y.usm_type == usm_type_other
858+
assert z.usm_type == du.get_coerced_usm_type([usm_type, usm_type_other])
861859

862860
@pytest.mark.parametrize("values", [-2, [-1, -2]], ids=["scalar", "list"])
863-
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
864-
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
865-
def test_obj_ndarray(self, values, usm_type_x, usm_type_y):
866-
x = dpnp.arange(5, usm_type=usm_type_x)
867-
y = dpnp.array([1, 4], usm_type=usm_type_y)
861+
@pytest.mark.parametrize("usm_type_other", list_of_usm_types)
862+
def test_obj_ndarray(self, values, usm_type, usm_type_other):
863+
x = dpnp.arange(5, usm_type=usm_type)
864+
y = dpnp.array([1, 4], usm_type=usm_type_other)
868865
z = dpnp.insert(x, y, values)
869866

870-
assert x.usm_type == usm_type_x
871-
assert y.usm_type == usm_type_y
872-
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
867+
assert x.usm_type == usm_type
868+
assert y.usm_type == usm_type_other
869+
assert z.usm_type == du.get_coerced_usm_type([usm_type, usm_type_other])
873870

874-
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
875871
@pytest.mark.parametrize("usm_type_y", list_of_usm_types)
876872
@pytest.mark.parametrize("usm_type_z", list_of_usm_types)
877-
def test_obj_values_ndarray(self, usm_type_x, usm_type_y, usm_type_z):
878-
x = dpnp.arange(5, usm_type=usm_type_x)
873+
def test_obj_values_ndarray(self, usm_type, usm_type_y, usm_type_z):
874+
x = dpnp.arange(5, usm_type=usm_type)
879875
y = dpnp.array([1, 4], usm_type=usm_type_y)
880876
z = dpnp.array([-1, -3], usm_type=usm_type_z)
881877
res = dpnp.insert(x, y, z)
882878

883-
assert x.usm_type == usm_type_x
879+
assert x.usm_type == usm_type
884880
assert y.usm_type == usm_type_y
885881
assert z.usm_type == usm_type_z
886882
assert res.usm_type == du.get_coerced_usm_type(
887-
[usm_type_x, usm_type_y, usm_type_z]
883+
[usm_type, usm_type_y, usm_type_z]
888884
)
889885

890886

@@ -1258,6 +1254,7 @@ def test_choose(usm_type_x, usm_type_ind):
12581254
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
12591255

12601256

1257+
@pytest.mark.parametrize("usm_type", list_of_usm_types)
12611258
class TestLinAlgebra:
12621259
@pytest.mark.parametrize(
12631260
"data, is_empty",
@@ -1269,7 +1266,6 @@ class TestLinAlgebra:
12691266
],
12701267
ids=["2D", "3D", "Empty_2D", "Empty_3D"],
12711268
)
1272-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
12731269
def test_cholesky(self, data, is_empty, usm_type):
12741270
dtype = dpnp.default_float_type()
12751271
if is_empty:
@@ -1280,7 +1276,6 @@ def test_cholesky(self, data, is_empty, usm_type):
12801276
result = dpnp.linalg.cholesky(x)
12811277
assert x.usm_type == result.usm_type
12821278

1283-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
12841279
@pytest.mark.parametrize(
12851280
"p", [None, -dpnp.inf, -2, -1, 1, 2, dpnp.inf, "fro"]
12861281
)
@@ -1292,7 +1287,6 @@ def test_cond(self, usm_type, p):
12921287
assert ia.usm_type == usm_type
12931288
assert result.usm_type == usm_type
12941289

1295-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
12961290
@pytest.mark.parametrize(
12971291
"shape, is_empty",
12981292
[
@@ -1328,7 +1322,6 @@ def test_det(self, shape, is_empty, usm_type):
13281322
[(4, 4), (0, 0), (2, 3, 3), (0, 2, 2), (1, 0, 0)],
13291323
ids=["(4, 4)", "(0, 0)", "(2, 3, 3)", "(0, 2, 2)", "(1, 0, 0)"],
13301324
)
1331-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13321325
def test_eigenvalue(self, func, shape, usm_type):
13331326
# Set a `hermitian` flag for generate_random_numpy_array() to
13341327
# get a symmetric array for eigh() and eigvalsh() or
@@ -1346,7 +1339,6 @@ def test_eigenvalue(self, func, shape, usm_type):
13461339

13471340
assert a.usm_type == dp_val.usm_type
13481341

1349-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13501342
@pytest.mark.parametrize(
13511343
"shape, is_empty",
13521344
[
@@ -1371,26 +1363,24 @@ def test_inv(self, shape, is_empty, usm_type):
13711363

13721364
assert x.usm_type == result.usm_type
13731365

1374-
@pytest.mark.parametrize("usm_type_a", list_of_usm_types)
1375-
@pytest.mark.parametrize("usm_type_b", list_of_usm_types)
1366+
@pytest.mark.parametrize("usm_type_other", list_of_usm_types)
13761367
@pytest.mark.parametrize(
13771368
["m", "n", "nrhs"],
13781369
[(4, 2, 2), (4, 0, 1), (4, 2, 0), (0, 0, 0)],
13791370
)
1380-
def test_lstsq(self, m, n, nrhs, usm_type_a, usm_type_b):
1381-
a = dpnp.arange(m * n, usm_type=usm_type_a).reshape(m, n)
1382-
b = dpnp.ones((m, nrhs), usm_type=usm_type_b)
1371+
def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other):
1372+
a = dpnp.arange(m * n, usm_type=usm_type).reshape(m, n)
1373+
b = dpnp.ones((m, nrhs), usm_type=usm_type_other)
13831374
result = dpnp.linalg.lstsq(a, b)
13841375

1385-
assert a.usm_type == usm_type_a
1386-
assert b.usm_type == usm_type_b
1376+
assert a.usm_type == usm_type
1377+
assert b.usm_type == usm_type_other
13871378
for param in result:
13881379
assert param.usm_type == du.get_coerced_usm_type(
1389-
[usm_type_a, usm_type_b]
1380+
[usm_type, usm_type_other]
13901381
)
13911382

13921383
@pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
1393-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13941384
def test_matrix_power(self, n, usm_type):
13951385
a = dpnp.array([[1, 2], [3, 5]], usm_type=usm_type)
13961386

@@ -1406,14 +1396,12 @@ def test_matrix_power(self, n, usm_type):
14061396
],
14071397
ids=["1-D array", "2-D array no tol", "2_d array with tol"],
14081398
)
1409-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
14101399
def test_matrix_rank(self, data, tol, usm_type):
14111400
a = dpnp.array(data, usm_type=usm_type)
14121401

14131402
result = dpnp.linalg.matrix_rank(a, tol=tol)
14141403
assert a.usm_type == result.usm_type
14151404

1416-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
14171405
def test_multi_dot(self, usm_type):
14181406
array_list = []
14191407
for num_array in [3, 5]: # number of arrays in multi_dot
@@ -1427,7 +1415,6 @@ def test_multi_dot(self, usm_type):
14271415
assert input_usm_type == usm_type
14281416
assert result.usm_type == usm_type
14291417

1430-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
14311418
@pytest.mark.parametrize(
14321419
"ord", [None, -dpnp.inf, -2, -1, 1, 2, 3, dpnp.inf, "fro", "nuc"]
14331420
)
@@ -1449,7 +1436,6 @@ def test_norm(self, usm_type, ord, axis):
14491436
assert ia.usm_type == usm_type
14501437
assert result.usm_type == usm_type
14511438

1452-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
14531439
@pytest.mark.parametrize(
14541440
"shape, hermitian",
14551441
[
@@ -1476,7 +1462,6 @@ def test_pinv(self, shape, hermitian, usm_type):
14761462
result = dpnp.linalg.pinv(a, hermitian=hermitian)
14771463
assert a.usm_type == result.usm_type
14781464

1479-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
14801465
@pytest.mark.parametrize(
14811466
"shape",
14821467
[(4, 4), (2, 0), (2, 2, 3), (0, 2, 3), (1, 0, 3)],
@@ -1496,7 +1481,6 @@ def test_qr(self, shape, mode, usm_type):
14961481
assert a.usm_type == dp_q.usm_type
14971482
assert a.usm_type == dp_r.usm_type
14981483

1499-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
15001484
@pytest.mark.parametrize(
15011485
"shape, is_empty",
15021486
[
@@ -1522,7 +1506,6 @@ def test_slogdet(self, shape, is_empty, usm_type):
15221506
assert x.usm_type == sign.usm_type
15231507
assert x.usm_type == logdet.usm_type
15241508

1525-
@pytest.mark.parametrize("usm_type_matrix", list_of_usm_types)
15261509
@pytest.mark.parametrize("usm_type_rhs", list_of_usm_types)
15271510
@pytest.mark.parametrize(
15281511
"matrix, rhs",
@@ -1546,18 +1529,15 @@ def test_slogdet(self, shape, is_empty, usm_type):
15461529
"3D_Matrix_and_3D_RHS",
15471530
],
15481531
)
1549-
def test_solve(self, matrix, rhs, usm_type_matrix, usm_type_rhs):
1550-
x = dpnp.array(matrix, usm_type=usm_type_matrix)
1532+
def test_solve(self, matrix, rhs, usm_type, usm_type_rhs):
1533+
x = dpnp.array(matrix, usm_type=usm_type)
15511534
y = dpnp.array(rhs, usm_type=usm_type_rhs)
15521535
z = dpnp.linalg.solve(x, y)
15531536

1554-
assert x.usm_type == usm_type_matrix
1537+
assert x.usm_type == usm_type
15551538
assert y.usm_type == usm_type_rhs
1556-
assert z.usm_type == du.get_coerced_usm_type(
1557-
[usm_type_matrix, usm_type_rhs]
1558-
)
1539+
assert z.usm_type == du.get_coerced_usm_type([usm_type, usm_type_rhs])
15591540

1560-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
15611541
@pytest.mark.parametrize("full_matrices_param", [True, False])
15621542
@pytest.mark.parametrize("compute_uv_param", [True, False])
15631543
@pytest.mark.parametrize(
@@ -1606,24 +1586,22 @@ def test_svd(self, usm_type, shape, full_matrices_param, compute_uv_param):
16061586

16071587
assert x.usm_type == s.usm_type
16081588

1609-
@pytest.mark.parametrize("usm_type", list_of_usm_types)
16101589
def test_tensorinv(self, usm_type):
16111590
a = dpnp.eye(12, usm_type=usm_type).reshape(12, 4, 3)
16121591
ainv = dpnp.linalg.tensorinv(a, ind=1)
16131592

16141593
assert a.usm_type == ainv.usm_type
16151594

1616-
@pytest.mark.parametrize("usm_type_a", list_of_usm_types)
1617-
@pytest.mark.parametrize("usm_type_b", list_of_usm_types)
1618-
def test_tensorsolve(self, usm_type_a, usm_type_b):
1595+
@pytest.mark.parametrize("usm_type_other", list_of_usm_types)
1596+
def test_tensorsolve(self, usm_type, usm_type_other):
16191597
data = numpy.random.randn(3, 2, 6)
1620-
a = dpnp.array(data, usm_type=usm_type_a)
1621-
b = dpnp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_b)
1598+
a = dpnp.array(data, usm_type=usm_type)
1599+
b = dpnp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_other)
16221600

16231601
result = dpnp.linalg.tensorsolve(a, b)
16241602

1625-
assert a.usm_type == usm_type_a
1626-
assert b.usm_type == usm_type_b
1603+
assert a.usm_type == usm_type
1604+
assert b.usm_type == usm_type_other
16271605
assert result.usm_type == du.get_coerced_usm_type(
1628-
[usm_type_a, usm_type_b]
1606+
[usm_type, usm_type_other]
16291607
)

0 commit comments

Comments
 (0)