Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,16 +532,29 @@ def _batched_svd(
batch_shape_orig,
)

if m < n:
trans_flag = True
else:
trans_flag = False

k = min(m, n)
if compute_uv:
if full_matrices:
u_shape = (m, m) + (batch_size,)
vt_shape = (n, n) + (batch_size,)
if trans_flag:
u_shape = (n, n) + (batch_size,)
vt_shape = (m, m) + (batch_size,)
else:
u_shape = (m, m) + (batch_size,)
vt_shape = (n, n) + (batch_size,)
jobu = ord("A")
jobvt = ord("A")
else:
u_shape = (m, k) + (batch_size,)
vt_shape = (k, n) + (batch_size,)
if trans_flag:
u_shape = (n, k) + (batch_size,)
vt_shape = (m, k) + (batch_size,)
else:
u_shape = (m, k) + (batch_size,)
vt_shape = (k, n) + (batch_size,)
jobu = ord("S")
jobvt = ord("S")
else:
Expand All @@ -554,7 +567,10 @@ def _batched_svd(

# Reorder the elements by moving the last two axes of `a` to the front
# to match fortran-like array order which is assumed by gesvd.
a = dpnp.moveaxis(a, (-2, -1), (0, 1))
if trans_flag:
a = dpnp.moveaxis(a, (-1, -2), (0, 1))
else:
a = dpnp.moveaxis(a, (-2, -1), (0, 1))

# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
# as input.
Expand Down Expand Up @@ -607,6 +623,17 @@ def _batched_svd(
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# reorder the axes to match C order by moving the last axis
# to the front
if trans_flag:
u = dpnp.moveaxis(u_h, (0, 2), (2, 0))
vt = dpnp.moveaxis(vt_h, (0, 2), (2, 0))
if a_ndim > 3:
u = u.reshape(batch_shape_orig + u.shape[-2:])
vt = vt.reshape(batch_shape_orig + vt.shape[-2:])
# dpnp.moveaxis can make the array non-contiguous if it is not 2D
# Convert to contiguous to align with NumPy
u = dpnp.ascontiguousarray(u)
vt = dpnp.ascontiguousarray(vt)
return vt, s, u
u = dpnp.moveaxis(u_h, -1, 0)
vt = dpnp.moveaxis(vt_h, -1, 0)
if a_ndim > 3:
Expand Down Expand Up @@ -2695,6 +2722,13 @@ def dpnp_svd(
a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type
)

if m < n:
a = a.transpose()
trans_flag = True
else:
a = a
trans_flag = False

# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
# Allocate 'F' order memory for dpnp arrays to comply with
# these requirements.
Expand All @@ -2719,13 +2753,21 @@ def dpnp_svd(
k = min(m, n)
if compute_uv:
if full_matrices:
u_shape = (m, m)
vt_shape = (n, n)
if trans_flag:
u_shape = (n, n)
vt_shape = (m, m)
else:
u_shape = (m, m)
vt_shape = (n, n)
jobu = ord("A")
jobvt = ord("A")
else:
u_shape = (m, k)
vt_shape = (k, n)
if trans_flag:
u_shape = (n, k)
vt_shape = (m, k)
else:
u_shape = (m, k)
vt_shape = (k, n)
jobu = ord("S")
jobvt = ord("S")
else:
Expand Down Expand Up @@ -2763,6 +2805,10 @@ def dpnp_svd(
if compute_uv:
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# Convert to contiguous to align with NumPy
if trans_flag:
u_h = u_h.transpose()
vt_h = vt_h.transpose()
return vt_h, s_h, u_h
u_h = dpnp.ascontiguousarray(u_h)
vt_h = dpnp.ascontiguousarray(vt_h)
return u_h, s_h, vt_h
Expand Down
15 changes: 0 additions & 15 deletions dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,6 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp):
array, full_matrices=self.full_matrices, compute_uv=False
)

# The issue was expected to be resolved once CMPLRLLVM-53771 is available,
# which has to be included in DPC++ 2024.1.0, but problem still exists
# on Windows
@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank3(self):
self.check_usv((2, 3, 4))
Expand All @@ -295,9 +289,6 @@ def test_svd_rank3(self):
self.check_usv((2, 4, 3))
self.check_usv((2, 32, 32))

@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank3_loop(self):
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
Expand Down Expand Up @@ -345,9 +336,6 @@ def test_svd_rank3_empty_array_compute_uv_false2(self, xp):
array, full_matrices=self.full_matrices, compute_uv=False
)

@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank4(self):
self.check_usv((2, 2, 3, 4))
Expand All @@ -357,9 +345,6 @@ def test_svd_rank4(self):
self.check_usv((2, 2, 4, 3))
self.check_usv((2, 2, 32, 32))

@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank4_loop(self):
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
Expand Down
Loading