@@ -99,6 +99,10 @@ def _raise_linalgerror_svd_nonconvergence(err, flag):
9999def _raise_linalgerror_lstsq (err , flag ):
100100 raise LinAlgError ("SVD did not converge in Linear Least Squares" )
101101
102+ def _raise_linalgerror_qr (err , flag ):
103+ raise LinAlgError ("Incorrect argument found while performing "
104+ "QR factorization" )
105+
102106def get_linalg_error_extobj (callback ):
103107 extobj = list (_linalg_error_extobj ) # make a copy
104108 extobj [2 ] = callback
@@ -776,15 +780,16 @@ def qr(a, mode='reduced'):
776780
777781 Parameters
778782 ----------
779- a : array_like, shape (M, N)
780- Matrix to be factored .
783+ a : array_like, shape (..., M, N)
784+ An array-like object with the dimensionality of at least 2 .
781785 mode : {'reduced', 'complete', 'r', 'raw'}, optional
782786 If K = min(M, N), then
783787
784- * 'reduced' : returns q, r with dimensions (M, K), (K, N) (default)
785- * 'complete' : returns q, r with dimensions (M, M), (M, N)
786- * 'r' : returns r only with dimensions (K, N)
787- * 'raw' : returns h, tau with dimensions (N, M), (K,)
788+ * 'reduced' : returns q, r with dimensions
789+ (..., M, K), (..., K, N) (default)
790+ * 'complete' : returns q, r with dimensions (..., M, M), (..., M, N)
791+ * 'r' : returns r only with dimensions (..., K, N)
792+ * 'raw' : returns h, tau with dimensions (..., N, M), (..., K,)
788793
789794 The options 'reduced', 'complete, and 'raw' are new in numpy 1.8,
790795 see the notes for more information. The default is 'reduced', and to
@@ -803,9 +808,13 @@ def qr(a, mode='reduced'):
803808 A matrix with orthonormal columns. When mode = 'complete' the
804809 result is an orthogonal/unitary matrix depending on whether or not
805810 a is real/complex. The determinant may be either +/- 1 in that
806- case.
811+ case. In case the number of dimensions in the input array is
812+ greater than 2 then a stack of the matrices with above properties
813+ is returned.
807814 r : ndarray of float or complex, optional
808- The upper-triangular matrix.
815+ The upper-triangular matrix or a stack of upper-triangular
816+ matrices if the number of dimensions in the input array is greater
817+ than 2.
809818 (h, tau) : ndarrays of np.double or np.cdouble, optional
810819 The array h contains the Householder reflectors that generate q
811820 along with r. The tau array contains scaling factors for the
@@ -853,6 +862,14 @@ def qr(a, mode='reduced'):
853862 >>> r2 = np.linalg.qr(a, mode='r')
854863 >>> np.allclose(r, r2) # mode='r' returns the same r as mode='full'
855864 True
865+ >>> a = np.random.normal(size=(3, 2, 2)) # Stack of 2 x 2 matrices as input
866+ >>> q, r = np.linalg.qr(a)
867+ >>> q.shape
868+ (3, 2, 2)
869+ >>> r.shape
870+ (3, 2, 2)
871+ >>> np.allclose(a, np.matmul(q, r))
872+ True
856873
857874 Example illustrating a common use of `qr`: solving of least squares
858875 problems
@@ -900,83 +917,58 @@ def qr(a, mode='reduced'):
900917 raise ValueError (f"Unrecognized mode '{ mode } '" )
901918
902919 a , wrap = _makearray (a )
903- _assert_2d (a )
904- m , n = a .shape
920+ _assert_stacked_2d (a )
921+ m , n = a .shape [ - 2 :]
905922 t , result_t = _commonType (a )
906- a = _fastCopyAndTranspose (t , a )
923+ a = a . astype (t , copy = True )
907924 a = _to_native_byte_order (a )
908925 mn = min (m , n )
909- tau = zeros ((mn ,), t )
910926
911- if isComplexType (t ):
912- lapack_routine = lapack_lite .zgeqrf
913- routine_name = 'zgeqrf'
927+ if m <= n :
928+ gufunc = _umath_linalg .qr_r_raw_m
914929 else :
915- lapack_routine = lapack_lite .dgeqrf
916- routine_name = 'dgeqrf'
917-
918- # calculate optimal size of work data 'work'
919- lwork = 1
920- work = zeros ((lwork ,), t )
921- results = lapack_routine (m , n , a , max (1 , m ), tau , work , - 1 , 0 )
922- if results ['info' ] != 0 :
923- raise LinAlgError ('%s returns %d' % (routine_name , results ['info' ]))
924-
925- # do qr decomposition
926- lwork = max (1 , n , int (abs (work [0 ])))
927- work = zeros ((lwork ,), t )
928- results = lapack_routine (m , n , a , max (1 , m ), tau , work , lwork , 0 )
929- if results ['info' ] != 0 :
930- raise LinAlgError ('%s returns %d' % (routine_name , results ['info' ]))
930+ gufunc = _umath_linalg .qr_r_raw_n
931+
932+ signature = 'D->D' if isComplexType (t ) else 'd->d'
933+ extobj = get_linalg_error_extobj (_raise_linalgerror_qr )
934+ tau = gufunc (a , signature = signature , extobj = extobj )
931935
932936 # handle modes that don't return q
933937 if mode == 'r' :
934- r = _fastCopyAndTranspose (result_t , a [:, :mn ])
935- return wrap (triu (r ))
938+ r = triu (a [..., :mn , :])
939+ r = r .astype (result_t , copy = False )
940+ return wrap (r )
936941
937942 if mode == 'raw' :
938- return a , tau
943+ q = transpose (a )
944+ q = q .astype (result_t , copy = False )
945+ tau = tau .astype (result_t , copy = False )
946+ return wrap (q ), tau
939947
940948 if mode == 'economic' :
941- if t != result_t :
942- a = a .astype (result_t , copy = False )
943- return wrap (a .T )
949+ a = a .astype (result_t , copy = False )
950+ return wrap (a )
944951
945- # generate q from a
952+ # mc is the number of columns in the resulting q
953+ # matrix. If the mode is complete then it is
954+ # same as number of rows, and if the mode is reduced,
955+ # then it is the minimum of number of rows and columns.
946956 if mode == 'complete' and m > n :
947957 mc = m
948- q = empty (( m , m ), t )
958+ gufunc = _umath_linalg . qr_complete
949959 else :
950960 mc = mn
951- q = empty ((n , m ), t )
952- q [:n ] = a
953-
954- if isComplexType (t ):
955- lapack_routine = lapack_lite .zungqr
956- routine_name = 'zungqr'
957- else :
958- lapack_routine = lapack_lite .dorgqr
959- routine_name = 'dorgqr'
961+ gufunc = _umath_linalg .qr_reduced
960962
961- # determine optimal lwork
962- lwork = 1
963- work = zeros ((lwork ,), t )
964- results = lapack_routine (m , mc , mn , q , max (1 , m ), tau , work , - 1 , 0 )
965- if results ['info' ] != 0 :
966- raise LinAlgError ('%s returns %d' % (routine_name , results ['info' ]))
967-
968- # compute q
969- lwork = max (1 , n , int (abs (work [0 ])))
970- work = zeros ((lwork ,), t )
971- results = lapack_routine (m , mc , mn , q , max (1 , m ), tau , work , lwork , 0 )
972- if results ['info' ] != 0 :
973- raise LinAlgError ('%s returns %d' % (routine_name , results ['info' ]))
974-
975- q = _fastCopyAndTranspose (result_t , q [:mc ])
976- r = _fastCopyAndTranspose (result_t , a [:, :mc ])
963+ signature = 'DD->D' if isComplexType (t ) else 'dd->d'
964+ extobj = get_linalg_error_extobj (_raise_linalgerror_qr )
965+ q = gufunc (a , tau , signature = signature , extobj = extobj )
966+ r = triu (a [..., :mc , :])
977967
978- return wrap (q ), wrap (triu (r ))
968+ q = q .astype (result_t , copy = False )
969+ r = r .astype (result_t , copy = False )
979970
971+ return wrap (q ), wrap (r )
980972
981973# Eigenvalues
982974
@@ -2173,7 +2165,7 @@ def lstsq(a, b, rcond="warn"):
21732165 equal to, or greater than its number of linearly independent columns).
21742166 If `a` is square and of full rank, then `x` (but for round-off error)
21752167 is the "exact" solution of the equation. Else, `x` minimizes the
2176- Euclidean 2-norm :math:`||b - ax||`. If there are multiple minimizing
2168+ Euclidean 2-norm :math:`||b - ax||`. If there are multiple minimizing
21772169 solutions, the one with the smallest 2-norm :math:`||x||` is returned.
21782170
21792171 Parameters
0 commit comments