Skip to content

Commit d140f56

Browse files
Update test_solve
1 parent 778d1ca commit d140f56

File tree

2 files changed

+28
-32
lines changed

2 files changed

+28
-32
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,40 +2392,34 @@ def test_where(device):
23922392
ids=[device.filter_string for device in valid_devices],
23932393
)
23942394
@pytest.mark.parametrize(
2395-
"matrix, vector",
2395+
"matrix, rhs",
23962396
[
23972397
([[1, 2], [3, 5]], numpy.empty((2, 0))),
23982398
([[1, 2], [3, 5]], [1, 2]),
23992399
(
24002400
[
2401-
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
2402-
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
2403-
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
2401+
[[1, 1], [0, 2]],
2402+
[[3, -1], [1, 2]],
2403+
],
2404+
[
2405+
[[6, -4], [9, -6]],
2406+
[[15, 1], [15, 1]],
24042407
],
2405-
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
24062408
),
24072409
],
24082410
ids=[
2409-
"2D_Matrix_Empty_Vector",
2410-
"2D_Matrix_1D_Vector",
2411-
"3D_Matrix_and_Vectors",
2411+
"2D_Matrix_Empty_RHS",
2412+
"2D_Matrix_1D_RHS",
2413+
"3D_Matrix_and_3D_RHS",
24122414
],
24132415
)
2414-
def test_solve(matrix, vector, device):
2416+
def test_solve(matrix, rhs, device):
24152417
a_np = numpy.array(matrix)
2416-
b_np = numpy.array(vector)
2418+
b_np = numpy.array(rhs)
24172419

24182420
a_dp = dpnp.array(a_np, device=device)
24192421
b_dp = dpnp.array(b_np, device=device)
24202422

2421-
# In numpy 2.0 the broadcast ambiguity has been removed and now
2422-
# b is treaded as a single vector if and only if it is 1-dimensional;
2423-
# for other cases this signature must be followed
2424-
# (..., m, m), (..., m, n) -> (..., m, n)
2425-
# https://github.com/numpy/numpy/pull/25914
2426-
if a_dp.ndim > 2 and numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
2427-
pytest.skip("SAT-6928")
2428-
24292423
result = dpnp.linalg.solve(a_dp, b_dp)
24302424
expected = numpy.linalg.solve(a_np, b_np)
24312425
assert_dtype_allclose(result, expected)

dpnp/tests/test_usm_type.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,37 +1285,39 @@ def test_fftshift(self, func, usm_type):
12851285
"usm_type_matrix", list_of_usm_types, ids=list_of_usm_types
12861286
)
12871287
@pytest.mark.parametrize(
1288-
"usm_type_vector", list_of_usm_types, ids=list_of_usm_types
1288+
"usm_type_rhs", list_of_usm_types, ids=list_of_usm_types
12891289
)
12901290
@pytest.mark.parametrize(
1291-
"matrix, vector",
1291+
"matrix, rhs",
12921292
[
1293-
([[1, 2], [3, 5]], dp.empty((2, 0))),
1293+
([[1, 2], [3, 5]], numpy.empty((2, 0))),
12941294
([[1, 2], [3, 5]], [1, 2]),
12951295
(
12961296
[
1297-
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
1298-
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
1299-
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
1297+
[[1, 1], [0, 2]],
1298+
[[3, -1], [1, 2]],
1299+
],
1300+
[
1301+
[[6, -4], [9, -6]],
1302+
[[15, 1], [15, 1]],
13001303
],
1301-
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
13021304
),
13031305
],
13041306
ids=[
1305-
"2D_Matrix_Empty_Vector",
1306-
"2D_Matrix_1D_Vector",
1307-
"3D_Matrix_and_Vectors",
1307+
"2D_Matrix_Empty_RHS",
1308+
"2D_Matrix_1D_RHS",
1309+
"3D_Matrix_and_3D_RHS",
13081310
],
13091311
)
1310-
def test_solve(matrix, vector, usm_type_matrix, usm_type_vector):
1312+
def test_solve(matrix, rhs, usm_type_matrix, usm_type_rhs):
13111313
x = dp.array(matrix, usm_type=usm_type_matrix)
1312-
y = dp.array(vector, usm_type=usm_type_vector)
1314+
y = dp.array(rhs, usm_type=usm_type_rhs)
13131315
z = dp.linalg.solve(x, y)
13141316

13151317
assert x.usm_type == usm_type_matrix
1316-
assert y.usm_type == usm_type_vector
1318+
assert y.usm_type == usm_type_rhs
13171319
assert z.usm_type == du.get_coerced_usm_type(
1318-
[usm_type_matrix, usm_type_vector]
1320+
[usm_type_matrix, usm_type_rhs]
13191321
)
13201322

13211323

0 commit comments

Comments
 (0)