Skip to content

Commit 0ae4f6b

Browse files
Test: update dpnp_svd/_batched_svd to handle case m<n
1 parent feae44f commit 0ae4f6b

File tree

1 file changed

+55
-9
lines changed

1 file changed

+55
-9
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -532,16 +532,29 @@ def _batched_svd(
532532
batch_shape_orig,
533533
)
534534

535+
if m < n:
536+
trans_flag = True
537+
else:
538+
trans_flag = False
539+
535540
k = min(m, n)
536541
if compute_uv:
537542
if full_matrices:
538-
u_shape = (m, m) + (batch_size,)
539-
vt_shape = (n, n) + (batch_size,)
543+
if trans_flag:
544+
u_shape = (n, n) + (batch_size,)
545+
vt_shape = (m, m) + (batch_size,)
546+
else:
547+
u_shape = (m, m) + (batch_size,)
548+
vt_shape = (n, n) + (batch_size,)
540549
jobu = ord("A")
541550
jobvt = ord("A")
542551
else:
543-
u_shape = (m, k) + (batch_size,)
544-
vt_shape = (k, n) + (batch_size,)
552+
if trans_flag:
553+
u_shape = (n, k) + (batch_size,)
554+
vt_shape = (m, k) + (batch_size,)
555+
else:
556+
u_shape = (m, k) + (batch_size,)
557+
vt_shape = (k, n) + (batch_size,)
545558
jobu = ord("S")
546559
jobvt = ord("S")
547560
else:
@@ -554,7 +567,10 @@ def _batched_svd(
554567

555568
# Reorder the elements by moving the last two axes of `a` to the front
556569
# to match fortran-like array order which is assumed by gesvd.
557-
a = dpnp.moveaxis(a, (-2, -1), (0, 1))
570+
if trans_flag:
571+
a = dpnp.moveaxis(a, (-1, -2), (0, 1))
572+
else:
573+
a = dpnp.moveaxis(a, (-2, -1), (0, 1))
558574

559575
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
560576
# as input.
@@ -607,6 +623,17 @@ def _batched_svd(
607623
# gesvd call writes `u_h` and `vt_h` in Fortran order;
608624
# reorder the axes to match C order by moving the last axis
609625
# to the front
626+
if trans_flag:
627+
u = dpnp.moveaxis(u_h, (0, 2), (2, 0))
628+
vt = dpnp.moveaxis(vt_h, (0, 2), (2, 0))
629+
if a_ndim > 3:
630+
u = u.reshape(batch_shape_orig + u.shape[-2:])
631+
vt = vt.reshape(batch_shape_orig + vt.shape[-2:])
632+
# dpnp.moveaxis can make the array non-contiguous if it is not 2D
633+
# Convert to contiguous to align with NumPy
634+
u = dpnp.ascontiguousarray(u)
635+
vt = dpnp.ascontiguousarray(vt)
636+
return vt, s, u
610637
u = dpnp.moveaxis(u_h, -1, 0)
611638
vt = dpnp.moveaxis(vt_h, -1, 0)
612639
if a_ndim > 3:
@@ -2695,6 +2722,13 @@ def dpnp_svd(
26952722
a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type
26962723
)
26972724

2725+
if m < n:
2726+
a = a.transpose()
2727+
trans_flag = True
2728+
else:
2729+
a = a
2730+
trans_flag = False
2731+
26982732
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
26992733
# Allocate 'F' order memory for dpnp arrays to comply with
27002734
# these requirements.
@@ -2719,13 +2753,21 @@ def dpnp_svd(
27192753
k = min(m, n)
27202754
if compute_uv:
27212755
if full_matrices:
2722-
u_shape = (m, m)
2723-
vt_shape = (n, n)
2756+
if trans_flag:
2757+
u_shape = (n, n)
2758+
vt_shape = (m, m)
2759+
else:
2760+
u_shape = (m, m)
2761+
vt_shape = (n, n)
27242762
jobu = ord("A")
27252763
jobvt = ord("A")
27262764
else:
2727-
u_shape = (m, k)
2728-
vt_shape = (k, n)
2765+
if trans_flag:
2766+
u_shape = (n, k)
2767+
vt_shape = (m, k)
2768+
else:
2769+
u_shape = (m, k)
2770+
vt_shape = (k, n)
27292771
jobu = ord("S")
27302772
jobvt = ord("S")
27312773
else:
@@ -2763,6 +2805,10 @@ def dpnp_svd(
27632805
if compute_uv:
27642806
# gesvd call writes `u_h` and `vt_h` in Fortran order;
27652807
# Convert to contiguous to align with NumPy
2808+
if trans_flag:
2809+
u_h = u_h.transpose()
2810+
vt_h = vt_h.transpose()
2811+
return vt_h, s_h, u_h
27662812
u_h = dpnp.ascontiguousarray(u_h)
27672813
vt_h = dpnp.ascontiguousarray(vt_h)
27682814
return u_h, s_h, vt_h

0 commit comments

Comments
 (0)