diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 2d7804234dec..0302c4bf20dc 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -532,16 +532,29 @@ def _batched_svd( batch_shape_orig, ) + if m < n: + trans_flag = True + else: + trans_flag = False + k = min(m, n) if compute_uv: if full_matrices: - u_shape = (m, m) + (batch_size,) - vt_shape = (n, n) + (batch_size,) + if trans_flag: + u_shape = (n, n) + (batch_size,) + vt_shape = (m, m) + (batch_size,) + else: + u_shape = (m, m) + (batch_size,) + vt_shape = (n, n) + (batch_size,) jobu = ord("A") jobvt = ord("A") else: - u_shape = (m, k) + (batch_size,) - vt_shape = (k, n) + (batch_size,) + if trans_flag: + u_shape = (n, k) + (batch_size,) + vt_shape = (m, k) + (batch_size,) + else: + u_shape = (m, k) + (batch_size,) + vt_shape = (k, n) + (batch_size,) jobu = ord("S") jobvt = ord("S") else: @@ -554,7 +567,10 @@ def _batched_svd( # Reorder the elements by moving the last two axes of `a` to the front # to match fortran-like array order which is assumed by gesvd. - a = dpnp.moveaxis(a, (-2, -1), (0, 1)) + if trans_flag: + a = dpnp.moveaxis(a, (-1, -2), (0, 1)) + else: + a = dpnp.moveaxis(a, (-2, -1), (0, 1)) # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array # as input. @@ -607,6 +623,17 @@ def _batched_svd( # gesvd call writes `u_h` and `vt_h` in Fortran order; # reorder the axes to match C order by moving the last axis # to the front + if trans_flag: + u = dpnp.moveaxis(u_h, (0, 2), (2, 0)) + vt = dpnp.moveaxis(vt_h, (0, 2), (2, 0)) + if a_ndim > 3: + u = u.reshape(batch_shape_orig + u.shape[-2:]) + vt = vt.reshape(batch_shape_orig + vt.shape[-2:]) + # dpnp.moveaxis can make the array non-contiguous if it is not 2D + # Convert to contiguous to align with NumPy + u = dpnp.ascontiguousarray(u) + vt = dpnp.ascontiguousarray(vt) + return vt, s, u u = dpnp.moveaxis(u_h, -1, 0) vt = dpnp.moveaxis(vt_h, -1, 0) if a_ndim > 3: @@ -2695,6 +2722,13 @@ def dpnp_svd( a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type ) + if m < n: + a = a.transpose() + trans_flag = True + else: + a = a + trans_flag = False + # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input. # Allocate 'F' order memory for dpnp arrays to comply with # these requirements. @@ -2719,13 +2753,21 @@ def dpnp_svd( k = min(m, n) if compute_uv: if full_matrices: - u_shape = (m, m) - vt_shape = (n, n) + if trans_flag: + u_shape = (n, n) + vt_shape = (m, m) + else: + u_shape = (m, m) + vt_shape = (n, n) jobu = ord("A") jobvt = ord("A") else: - u_shape = (m, k) - vt_shape = (k, n) + if trans_flag: + u_shape = (n, k) + vt_shape = (m, k) + else: + u_shape = (m, k) + vt_shape = (k, n) jobu = ord("S") jobvt = ord("S") else: @@ -2763,6 +2805,10 @@ def dpnp_svd( if compute_uv: # gesvd call writes `u_h` and `vt_h` in Fortran order; # Convert to contiguous to align with NumPy + if trans_flag: + u_h = u_h.transpose() + vt_h = vt_h.transpose() + return vt_h, s_h, u_h u_h = dpnp.ascontiguousarray(u_h) vt_h = dpnp.ascontiguousarray(vt_h) return u_h, s_h, vt_h diff --git a/dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py b/dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py index d2c3ac69aacc..43d623f75a54 100644 --- a/dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py +++ b/dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -280,12 +280,6 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp): array, full_matrices=self.full_matrices, compute_uv=False ) - # The issue was expected to be resolved once CMPLRLLVM-53771 is available, - # which has to be included in DPC++ 2024.1.0, but problem still exists - # on Windows - @pytest.mark.skipif( - is_cpu_device() and is_win_platform(), reason="SAT-7145" - ) @_condition.repeat(3, 10) def test_svd_rank3(self): self.check_usv((2, 3, 4)) @@ -295,9 +289,6 @@ def test_svd_rank3(self): self.check_usv((2, 4, 3)) self.check_usv((2, 32, 32)) - @pytest.mark.skipif( - is_cpu_device() and is_win_platform(), reason="SAT-7145" - ) @_condition.repeat(3, 10) def test_svd_rank3_loop(self): # This tests the loop-based batched gesvd on CUDA (_gesvd_batched) @@ -345,9 +336,6 @@ def test_svd_rank3_empty_array_compute_uv_false2(self, xp): array, full_matrices=self.full_matrices, compute_uv=False ) - @pytest.mark.skipif( - is_cpu_device() and is_win_platform(), reason="SAT-7145" - ) @_condition.repeat(3, 10) def test_svd_rank4(self): self.check_usv((2, 2, 3, 4)) @@ -357,9 +345,6 @@ def test_svd_rank4(self): self.check_usv((2, 2, 4, 3)) self.check_usv((2, 2, 32, 32)) - @pytest.mark.skipif( - is_cpu_device() and is_win_platform(), reason="SAT-7145" - ) @_condition.repeat(3, 10) def test_svd_rank4_loop(self): # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)