@@ -532,16 +532,29 @@ def _batched_svd(
532532 batch_shape_orig ,
533533 )
534534
535+ if m < n :
536+ trans_flag = True
537+ else :
538+ trans_flag = False
539+
535540 k = min (m , n )
536541 if compute_uv :
537542 if full_matrices :
538- u_shape = (m , m ) + (batch_size ,)
539- vt_shape = (n , n ) + (batch_size ,)
543+ if trans_flag :
544+ u_shape = (n , n ) + (batch_size ,)
545+ vt_shape = (m , m ) + (batch_size ,)
546+ else :
547+ u_shape = (m , m ) + (batch_size ,)
548+ vt_shape = (n , n ) + (batch_size ,)
540549 jobu = ord ("A" )
541550 jobvt = ord ("A" )
542551 else :
543- u_shape = (m , k ) + (batch_size ,)
544- vt_shape = (k , n ) + (batch_size ,)
552+ if trans_flag :
553+ u_shape = (n , k ) + (batch_size ,)
554+ vt_shape = (m , k ) + (batch_size ,)
555+ else :
556+ u_shape = (m , k ) + (batch_size ,)
557+ vt_shape = (k , n ) + (batch_size ,)
545558 jobu = ord ("S" )
546559 jobvt = ord ("S" )
547560 else :
@@ -554,7 +567,10 @@ def _batched_svd(
554567
555568 # Reorder the elements by moving the last two axes of `a` to the front
556569 # to match fortran-like array order which is assumed by gesvd.
557- a = dpnp .moveaxis (a , (- 2 , - 1 ), (0 , 1 ))
570+ if trans_flag :
571+ a = dpnp .moveaxis (a , (- 1 , - 2 ), (0 , 1 ))
572+ else :
573+ a = dpnp .moveaxis (a , (- 2 , - 1 ), (0 , 1 ))
558574
559575 # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
560576 # as input.
@@ -607,6 +623,17 @@ def _batched_svd(
607623 # gesvd call writes `u_h` and `vt_h` in Fortran order;
608624 # reorder the axes to match C order by moving the last axis
609625 # to the front
626+ if trans_flag :
627+ u = dpnp .moveaxis (u_h , (0 , 2 ), (2 , 0 ))
628+ vt = dpnp .moveaxis (vt_h , (0 , 2 ), (2 , 0 ))
629+ if a_ndim > 3 :
630+ u = u .reshape (batch_shape_orig + u .shape [- 2 :])
631+ vt = vt .reshape (batch_shape_orig + vt .shape [- 2 :])
632+ # dpnp.moveaxis can make the array non-contiguous if it is not 2D
633+ # Convert to contiguous to align with NumPy
634+ u = dpnp .ascontiguousarray (u )
635+ vt = dpnp .ascontiguousarray (vt )
636+ return vt , s , u
610637 u = dpnp .moveaxis (u_h , - 1 , 0 )
611638 vt = dpnp .moveaxis (vt_h , - 1 , 0 )
612639 if a_ndim > 3 :
@@ -2695,6 +2722,13 @@ def dpnp_svd(
26952722 a , uv_type , s_type , full_matrices , compute_uv , exec_q , usm_type
26962723 )
26972724
2725+ if m < n :
2726+ a = a .transpose ()
2727+ trans_flag = True
2728+ else :
2729+ a = a
2730+ trans_flag = False
2731+
26982732 # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
26992733 # Allocate 'F' order memory for dpnp arrays to comply with
27002734 # these requirements.
@@ -2719,13 +2753,21 @@ def dpnp_svd(
27192753 k = min (m , n )
27202754 if compute_uv :
27212755 if full_matrices :
2722- u_shape = (m , m )
2723- vt_shape = (n , n )
2756+ if trans_flag :
2757+ u_shape = (n , n )
2758+ vt_shape = (m , m )
2759+ else :
2760+ u_shape = (m , m )
2761+ vt_shape = (n , n )
27242762 jobu = ord ("A" )
27252763 jobvt = ord ("A" )
27262764 else :
2727- u_shape = (m , k )
2728- vt_shape = (k , n )
2765+ if trans_flag :
2766+ u_shape = (n , k )
2767+ vt_shape = (m , k )
2768+ else :
2769+ u_shape = (m , k )
2770+ vt_shape = (k , n )
27292771 jobu = ord ("S" )
27302772 jobvt = ord ("S" )
27312773 else :
@@ -2763,6 +2805,10 @@ def dpnp_svd(
27632805 if compute_uv :
27642806 # gesvd call writes `u_h` and `vt_h` in Fortran order;
27652807 # Convert to contiguous to align with NumPy
2808+ if trans_flag :
2809+ u_h = u_h .transpose ()
2810+ vt_h = vt_h .transpose ()
2811+ return vt_h , s_h , u_h
27662812 u_h = dpnp .ascontiguousarray (u_h )
27672813 vt_h = dpnp .ascontiguousarray (vt_h )
27682814 return u_h , s_h , vt_h
0 commit comments