Skip to content

Commit ae17d9b

Browse files
committed
Test namedtuple returned by affected linalg functions
1 parent faf1332 commit ae17d9b

File tree

1 file changed

+52
-22
lines changed

1 file changed

+52
-22
lines changed

dpnp/tests/test_linalg.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,8 @@ def test_eigenvalues(self, func, shape, dtype, order):
521521
# we verify them through the eigen equation A*v=w*v.
522522
if func in ("eig", "eigh"):
523523
w, _ = getattr(numpy.linalg, func)(a)
524-
w_dp, v_dp = getattr(dpnp.linalg, func)(a_dp)
524+
result = getattr(dpnp.linalg, func)(a_dp)
525+
w_dp, v_dp = result.eigenvalues, result.eigenvectors
525526

526527
self.assert_eigen_decomposition(a_dp, w_dp, v_dp)
527528

@@ -545,7 +546,8 @@ def test_eigenvalue_empty(self, func, shape, dtype):
545546

546547
if func == "eig":
547548
w, v = getattr(numpy.linalg, func)(a_np)
548-
w_dp, v_dp = getattr(dpnp.linalg, func)(a_dp)
549+
result = getattr(dpnp.linalg, func)(a_dp)
550+
w_dp, v_dp = result.eigenvalues, result.eigenvectors
549551

550552
assert_dtype_allclose(v_dp, v)
551553

@@ -2388,16 +2390,18 @@ def test_qr(self, dtype, shape, mode):
23882390
dpnp_r = dpnp.linalg.qr(ia, mode)
23892391
else:
23902392
np_q, np_r = numpy.linalg.qr(a, mode)
2391-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
23922393

23932394
# check decomposition
23942395
if mode in ("complete", "reduced"):
2396+
result = dpnp.linalg.qr(ia, mode)
2397+
dpnp_q, dpnp_r = result.Q, result.R
23952398
assert_almost_equal(
23962399
dpnp.matmul(dpnp_q, dpnp_r),
23972400
a,
23982401
decimal=5,
23992402
)
24002403
else: # mode=="raw"
2404+
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
24012405
assert_dtype_allclose(dpnp_q, np_q)
24022406

24032407
if mode in ("raw", "r"):
@@ -2421,15 +2425,18 @@ def test_qr_large(self, dtype, shape, mode):
24212425
dpnp_r = dpnp.linalg.qr(ia, mode)
24222426
else:
24232427
np_q, np_r = numpy.linalg.qr(a, mode)
2424-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
2428+
24252429
# check decomposition
24262430
if mode in ("complete", "reduced"):
2431+
result = dpnp.linalg.qr(ia, mode)
2432+
dpnp_q, dpnp_r = result.Q, result.R
24272433
assert_almost_equal(
24282434
dpnp.matmul(dpnp_q, dpnp_r),
24292435
a,
24302436
decimal=5,
24312437
)
24322438
else: # mode=="raw"
2439+
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
24332440
assert_allclose(np_q, dpnp_q, atol=1e-4)
24342441
if mode in ("raw", "r"):
24352442
assert_allclose(np_r, dpnp_r, atol=1e-4)
@@ -2457,7 +2464,12 @@ def test_qr_empty(self, dtype, shape, mode):
24572464
dpnp_r = dpnp.linalg.qr(ia, mode)
24582465
else:
24592466
np_q, np_r = numpy.linalg.qr(a, mode)
2460-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
2467+
2468+
if mode in ("complete", "reduced"):
2469+
result = dpnp.linalg.qr(ia, mode)
2470+
dpnp_q, dpnp_r = result.Q, result.R
2471+
else:
2472+
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
24612473

24622474
assert_dtype_allclose(dpnp_q, np_q)
24632475

@@ -2474,7 +2486,12 @@ def test_qr_strides(self, mode):
24742486
dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode)
24752487
else:
24762488
np_q, np_r = numpy.linalg.qr(a[::2, ::2], mode)
2477-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode)
2489+
2490+
if mode in ("complete", "reduced"):
2491+
result = dpnp.linalg.qr(ia[::2, ::2], mode)
2492+
dpnp_q, dpnp_r = result.Q, result.R
2493+
else:
2494+
dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::2, ::2], mode)
24782495

24792496
assert_dtype_allclose(dpnp_q, np_q)
24802497

@@ -2486,7 +2503,12 @@ def test_qr_strides(self, mode):
24862503
dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode)
24872504
else:
24882505
np_q, np_r = numpy.linalg.qr(a[::-2, ::-2], mode)
2489-
dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode)
2506+
2507+
if mode in ("complete", "reduced"):
2508+
result = dpnp.linalg.qr(ia[::-2, ::-2], mode)
2509+
dpnp_q, dpnp_r = result.Q, result.R
2510+
else:
2511+
dpnp_q, dpnp_r = dpnp.linalg.qr(ia[::-2, ::-2], mode)
24902512

24912513
assert_dtype_allclose(dpnp_q, np_q)
24922514

@@ -2660,7 +2682,8 @@ def test_slogdet_2d(self, dtype):
26602682
a_dp = dpnp.array(a_np)
26612683

26622684
sign_expected, logdet_expected = numpy.linalg.slogdet(a_np)
2663-
sign_result, logdet_result = dpnp.linalg.slogdet(a_dp)
2685+
result = dpnp.linalg.slogdet(a_dp)
2686+
sign_result, logdet_result = result.sign, result.logabsdet
26642687

26652688
assert_allclose(sign_expected, sign_result)
26662689
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
@@ -2678,7 +2701,8 @@ def test_slogdet_3d(self, dtype):
26782701
a_dp = dpnp.array(a_np)
26792702

26802703
sign_expected, logdet_expected = numpy.linalg.slogdet(a_np)
2681-
sign_result, logdet_result = dpnp.linalg.slogdet(a_dp)
2704+
result = dpnp.linalg.slogdet(a_dp)
2705+
sign_result, logdet_result = result.sign, result.logabsdet
26822706

26832707
assert_allclose(sign_expected, sign_result)
26842708
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
@@ -2698,13 +2722,15 @@ def test_slogdet_strides(self):
26982722

26992723
# positive strides
27002724
sign_expected, logdet_expected = numpy.linalg.slogdet(a_np[::2, ::2])
2701-
sign_result, logdet_result = dpnp.linalg.slogdet(a_dp[::2, ::2])
2725+
result = dpnp.linalg.slogdet(a_dp[::2, ::2])
2726+
sign_result, logdet_result = result.sign, result.logabsdet
27022727
assert_allclose(sign_expected, sign_result)
27032728
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
27042729

27052730
# negative strides
27062731
sign_expected, logdet_expected = numpy.linalg.slogdet(a_np[::-2, ::-2])
2707-
sign_result, logdet_result = dpnp.linalg.slogdet(a_dp[::-2, ::-2])
2732+
result = dpnp.linalg.slogdet(a_dp[::-2, ::-2])
2733+
sign_result, logdet_result = result.sign, result.logabsdet
27082734
assert_allclose(sign_expected, sign_result)
27092735
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
27102736

@@ -2732,7 +2758,8 @@ def test_slogdet_singular_matrix(self, matrix):
27322758
a_dp = dpnp.array(a_np)
27332759

27342760
sign_expected, logdet_expected = numpy.linalg.slogdet(a_np)
2735-
sign_result, logdet_result = dpnp.linalg.slogdet(a_dp)
2761+
result = dpnp.linalg.slogdet(a_dp)
2762+
sign_result, logdet_result = result.sign, result.logabsdet
27362763

27372764
assert_allclose(sign_expected, sign_result)
27382765
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
@@ -2748,7 +2775,8 @@ def test_slogdet_singular_matrix_3D(self):
27482775
a_dp = dpnp.array(a_np)
27492776

27502777
sign_expected, logdet_expected = numpy.linalg.slogdet(a_np)
2751-
sign_result, logdet_result = dpnp.linalg.slogdet(a_dp)
2778+
result = dpnp.linalg.slogdet(a_dp)
2779+
sign_result, logdet_result = result.sign, result.logabsdet
27522780

27532781
assert_allclose(sign_expected, sign_result)
27542782
assert_allclose(logdet_expected, logdet_result, rtol=1e-3, atol=1e-4)
@@ -2841,13 +2869,14 @@ def test_svd(self, dtype, shape):
28412869
a = numpy.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)
28422870
dp_a = dpnp.array(a)
28432871

2844-
np_u, np_s, np_vt = numpy.linalg.svd(a)
2845-
dp_u, dp_s, dp_vt = dpnp.linalg.svd(dp_a)
2872+
np_u, np_s, np_vh = numpy.linalg.svd(a)
2873+
result = dpnp.linalg.svd(dp_a)
2874+
dp_u, dp_s, dp_vh = result.U, result.S, result.Vh
28462875

2847-
self.check_types_shapes(dp_u, dp_s, dp_vt, np_u, np_s, np_vt)
2876+
self.check_types_shapes(dp_u, dp_s, dp_vh, np_u, np_s, np_vh)
28482877
self.get_tol(dtype)
28492878
self.check_decomposition(
2850-
dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, True
2879+
dp_a, dp_u, dp_s, dp_vh, np_u, np_s, np_vh, True
28512880
)
28522881

28532882
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@@ -2860,25 +2889,26 @@ def test_svd_hermitian(self, dtype, compute_vt, shape):
28602889
dp_a = dpnp.array(a)
28612890

28622891
if compute_vt:
2863-
np_u, np_s, np_vt = numpy.linalg.svd(
2892+
np_u, np_s, np_vh = numpy.linalg.svd(
28642893
a, compute_uv=compute_vt, hermitian=True
28652894
)
2866-
dp_u, dp_s, dp_vt = dpnp.linalg.svd(
2895+
result = dpnp.linalg.svd(
28672896
dp_a, compute_uv=compute_vt, hermitian=True
28682897
)
2898+
dp_u, dp_s, dp_vh = result.U, result.S, result.Vh
28692899
else:
28702900
np_s = numpy.linalg.svd(a, compute_uv=compute_vt, hermitian=True)
28712901
dp_s = dpnp.linalg.svd(dp_a, compute_uv=compute_vt, hermitian=True)
2872-
np_u = np_vt = dp_u = dp_vt = None
2902+
np_u = np_vh = dp_u = dp_vh = None
28732903

28742904
self.check_types_shapes(
2875-
dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt
2905+
dp_u, dp_s, dp_vh, np_u, np_s, np_vh, compute_vt
28762906
)
28772907

28782908
self.get_tol(dtype)
28792909

28802910
self.check_decomposition(
2881-
dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt
2911+
dp_a, dp_u, dp_s, dp_vh, np_u, np_s, np_vh, compute_vt
28822912
)
28832913

28842914
def test_svd_errors(self):

0 commit comments

Comments
 (0)