Skip to content

Commit a4e931f

Browse files
authored
Merge pull request numpy#19151 from czgdp1807/stack_mat
ENH: Vectorising np.linalg.qr
2 parents f353371 + 6e405d5 commit a4e931f

File tree

4 files changed

+1093
-283
lines changed

4 files changed

+1093
-283
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
`numpy.linalg.qr` accepts stacked matrices as inputs
2+
----------------------------------------------------
3+
4+
`numpy.linalg.qr` is able to produce results for stacked matrices as inputs.
5+
Moreover, the implementation of QR decomposition has been shifted to C
6+
from Python.

numpy/linalg/linalg.py

Lines changed: 58 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def _raise_linalgerror_svd_nonconvergence(err, flag):
9999
def _raise_linalgerror_lstsq(err, flag):
100100
raise LinAlgError("SVD did not converge in Linear Least Squares")
101101

102+
def _raise_linalgerror_qr(err, flag):
103+
raise LinAlgError("Incorrect argument found while performing "
104+
"QR factorization")
105+
102106
def get_linalg_error_extobj(callback):
103107
extobj = list(_linalg_error_extobj) # make a copy
104108
extobj[2] = callback
@@ -776,15 +780,16 @@ def qr(a, mode='reduced'):
776780
777781
Parameters
778782
----------
779-
a : array_like, shape (M, N)
780-
Matrix to be factored.
783+
a : array_like, shape (..., M, N)
784+
An array-like object with the dimensionality of at least 2.
781785
mode : {'reduced', 'complete', 'r', 'raw'}, optional
782786
If K = min(M, N), then
783787
784-
* 'reduced' : returns q, r with dimensions (M, K), (K, N) (default)
785-
* 'complete' : returns q, r with dimensions (M, M), (M, N)
786-
* 'r' : returns r only with dimensions (K, N)
787-
* 'raw' : returns h, tau with dimensions (N, M), (K,)
788+
* 'reduced' : returns q, r with dimensions
789+
(..., M, K), (..., K, N) (default)
790+
* 'complete' : returns q, r with dimensions (..., M, M), (..., M, N)
791+
* 'r' : returns r only with dimensions (..., K, N)
792+
* 'raw' : returns h, tau with dimensions (..., N, M), (..., K,)
788793
789794
The options 'reduced', 'complete, and 'raw' are new in numpy 1.8,
790795
see the notes for more information. The default is 'reduced', and to
@@ -803,9 +808,13 @@ def qr(a, mode='reduced'):
803808
A matrix with orthonormal columns. When mode = 'complete' the
804809
result is an orthogonal/unitary matrix depending on whether or not
805810
a is real/complex. The determinant may be either +/- 1 in that
806-
case.
811+
case. In case the number of dimensions in the input array is
812+
greater than 2 then a stack of the matrices with above properties
813+
is returned.
807814
r : ndarray of float or complex, optional
808-
The upper-triangular matrix.
815+
The upper-triangular matrix or a stack of upper-triangular
816+
matrices if the number of dimensions in the input array is greater
817+
than 2.
809818
(h, tau) : ndarrays of np.double or np.cdouble, optional
810819
The array h contains the Householder reflectors that generate q
811820
along with r. The tau array contains scaling factors for the
@@ -853,6 +862,14 @@ def qr(a, mode='reduced'):
853862
>>> r2 = np.linalg.qr(a, mode='r')
854863
>>> np.allclose(r, r2) # mode='r' returns the same r as mode='full'
855864
True
865+
>>> a = np.random.normal(size=(3, 2, 2)) # Stack of 2 x 2 matrices as input
866+
>>> q, r = np.linalg.qr(a)
867+
>>> q.shape
868+
(3, 2, 2)
869+
>>> r.shape
870+
(3, 2, 2)
871+
>>> np.allclose(a, np.matmul(q, r))
872+
True
856873
857874
Example illustrating a common use of `qr`: solving of least squares
858875
problems
@@ -900,83 +917,58 @@ def qr(a, mode='reduced'):
900917
raise ValueError(f"Unrecognized mode '{mode}'")
901918

902919
a, wrap = _makearray(a)
903-
_assert_2d(a)
904-
m, n = a.shape
920+
_assert_stacked_2d(a)
921+
m, n = a.shape[-2:]
905922
t, result_t = _commonType(a)
906-
a = _fastCopyAndTranspose(t, a)
923+
a = a.astype(t, copy=True)
907924
a = _to_native_byte_order(a)
908925
mn = min(m, n)
909-
tau = zeros((mn,), t)
910926

911-
if isComplexType(t):
912-
lapack_routine = lapack_lite.zgeqrf
913-
routine_name = 'zgeqrf'
927+
if m <= n:
928+
gufunc = _umath_linalg.qr_r_raw_m
914929
else:
915-
lapack_routine = lapack_lite.dgeqrf
916-
routine_name = 'dgeqrf'
917-
918-
# calculate optimal size of work data 'work'
919-
lwork = 1
920-
work = zeros((lwork,), t)
921-
results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0)
922-
if results['info'] != 0:
923-
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
924-
925-
# do qr decomposition
926-
lwork = max(1, n, int(abs(work[0])))
927-
work = zeros((lwork,), t)
928-
results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0)
929-
if results['info'] != 0:
930-
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
930+
gufunc = _umath_linalg.qr_r_raw_n
931+
932+
signature = 'D->D' if isComplexType(t) else 'd->d'
933+
extobj = get_linalg_error_extobj(_raise_linalgerror_qr)
934+
tau = gufunc(a, signature=signature, extobj=extobj)
931935

932936
# handle modes that don't return q
933937
if mode == 'r':
934-
r = _fastCopyAndTranspose(result_t, a[:, :mn])
935-
return wrap(triu(r))
938+
r = triu(a[..., :mn, :])
939+
r = r.astype(result_t, copy=False)
940+
return wrap(r)
936941

937942
if mode == 'raw':
938-
return a, tau
943+
q = transpose(a)
944+
q = q.astype(result_t, copy=False)
945+
tau = tau.astype(result_t, copy=False)
946+
return wrap(q), tau
939947

940948
if mode == 'economic':
941-
if t != result_t :
942-
a = a.astype(result_t, copy=False)
943-
return wrap(a.T)
949+
a = a.astype(result_t, copy=False)
950+
return wrap(a)
944951

945-
# generate q from a
952+
# mc is the number of columns in the resulting q
953+
# matrix. If the mode is complete then it is
954+
# same as number of rows, and if the mode is reduced,
955+
# then it is the minimum of number of rows and columns.
946956
if mode == 'complete' and m > n:
947957
mc = m
948-
q = empty((m, m), t)
958+
gufunc = _umath_linalg.qr_complete
949959
else:
950960
mc = mn
951-
q = empty((n, m), t)
952-
q[:n] = a
953-
954-
if isComplexType(t):
955-
lapack_routine = lapack_lite.zungqr
956-
routine_name = 'zungqr'
957-
else:
958-
lapack_routine = lapack_lite.dorgqr
959-
routine_name = 'dorgqr'
961+
gufunc = _umath_linalg.qr_reduced
960962

961-
# determine optimal lwork
962-
lwork = 1
963-
work = zeros((lwork,), t)
964-
results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0)
965-
if results['info'] != 0:
966-
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
967-
968-
# compute q
969-
lwork = max(1, n, int(abs(work[0])))
970-
work = zeros((lwork,), t)
971-
results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0)
972-
if results['info'] != 0:
973-
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
974-
975-
q = _fastCopyAndTranspose(result_t, q[:mc])
976-
r = _fastCopyAndTranspose(result_t, a[:, :mc])
963+
signature = 'DD->D' if isComplexType(t) else 'dd->d'
964+
extobj = get_linalg_error_extobj(_raise_linalgerror_qr)
965+
q = gufunc(a, tau, signature=signature, extobj=extobj)
966+
r = triu(a[..., :mc, :])
977967

978-
return wrap(q), wrap(triu(r))
968+
q = q.astype(result_t, copy=False)
969+
r = r.astype(result_t, copy=False)
979970

971+
return wrap(q), wrap(r)
980972

981973
# Eigenvalues
982974

@@ -2173,7 +2165,7 @@ def lstsq(a, b, rcond="warn"):
21732165
equal to, or greater than its number of linearly independent columns).
21742166
If `a` is square and of full rank, then `x` (but for round-off error)
21752167
is the "exact" solution of the equation. Else, `x` minimizes the
2176-
Euclidean 2-norm :math:`||b - ax||`. If there are multiple minimizing
2168+
Euclidean 2-norm :math:`||b - ax||`. If there are multiple minimizing
21772169
solutions, the one with the smallest 2-norm :math:`||x||` is returned.
21782170
21792171
Parameters

numpy/linalg/tests/test_linalg.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
""" Test functions for linalg module
22
33
"""
4+
from numpy.core.fromnumeric import shape
45
import os
56
import sys
67
import itertools
@@ -11,6 +12,7 @@
1112

1213
import numpy as np
1314
from numpy import array, single, double, csingle, cdouble, dot, identity, matmul
15+
from numpy.core import swapaxes
1416
from numpy import multiply, atleast_2d, inf, asarray
1517
from numpy import linalg
1618
from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError
@@ -1710,6 +1712,66 @@ def test_mode_all_but_economic(self):
17101712
self.check_qr(m2)
17111713
self.check_qr(m2.T)
17121714

1715+
def check_qr_stacked(self, a):
1716+
# This test expects the argument `a` to be an ndarray or
1717+
# a subclass of an ndarray of inexact type.
1718+
a_type = type(a)
1719+
a_dtype = a.dtype
1720+
m, n = a.shape[-2:]
1721+
k = min(m, n)
1722+
1723+
# mode == 'complete'
1724+
q, r = linalg.qr(a, mode='complete')
1725+
assert_(q.dtype == a_dtype)
1726+
assert_(r.dtype == a_dtype)
1727+
assert_(isinstance(q, a_type))
1728+
assert_(isinstance(r, a_type))
1729+
assert_(q.shape[-2:] == (m, m))
1730+
assert_(r.shape[-2:] == (m, n))
1731+
assert_almost_equal(matmul(q, r), a)
1732+
I_mat = np.identity(q.shape[-1])
1733+
stack_I_mat = np.broadcast_to(I_mat,
1734+
q.shape[:-2] + (q.shape[-1],)*2)
1735+
assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat)
1736+
assert_almost_equal(np.triu(r[..., :, :]), r)
1737+
1738+
# mode == 'reduced'
1739+
q1, r1 = linalg.qr(a, mode='reduced')
1740+
assert_(q1.dtype == a_dtype)
1741+
assert_(r1.dtype == a_dtype)
1742+
assert_(isinstance(q1, a_type))
1743+
assert_(isinstance(r1, a_type))
1744+
assert_(q1.shape[-2:] == (m, k))
1745+
assert_(r1.shape[-2:] == (k, n))
1746+
assert_almost_equal(matmul(q1, r1), a)
1747+
I_mat = np.identity(q1.shape[-1])
1748+
stack_I_mat = np.broadcast_to(I_mat,
1749+
q1.shape[:-2] + (q1.shape[-1],)*2)
1750+
assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1),
1751+
stack_I_mat)
1752+
assert_almost_equal(np.triu(r1[..., :, :]), r1)
1753+
1754+
# mode == 'r'
1755+
r2 = linalg.qr(a, mode='r')
1756+
assert_(r2.dtype == a_dtype)
1757+
assert_(isinstance(r2, a_type))
1758+
assert_almost_equal(r2, r1)
1759+
1760+
@pytest.mark.parametrize("size", [
1761+
(3, 4), (4, 3), (4, 4),
1762+
(3, 0), (0, 3)])
1763+
@pytest.mark.parametrize("outer_size", [
1764+
(2, 2), (2,), (2, 3, 4)])
1765+
@pytest.mark.parametrize("dt", [
1766+
np.single, np.double,
1767+
np.csingle, np.cdouble])
1768+
def test_stacked_inputs(self, outer_size, size, dt):
1769+
1770+
A = np.random.normal(size=outer_size + size).astype(dt)
1771+
B = np.random.normal(size=outer_size + size).astype(dt)
1772+
self.check_qr_stacked(A)
1773+
self.check_qr_stacked(A + 1.j*B)
1774+
17131775

17141776
class TestCholesky:
17151777
# TODO: are there no other tests for cholesky?

0 commit comments

Comments
 (0)