Skip to content

Commit 22d34e9

Browse files
authored
Merge pull request numpy#27060 from WarrenWeckesser/coalesce-linalg-gufuncs
MAINT: linalg: Simplify some linalg gufuncs.
2 parents 08d6004 + 755e959 commit 22d34e9

File tree

2 files changed

+123
-107
lines changed

2 files changed

+123
-107
lines changed

numpy/linalg/_linalg.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,15 +1092,10 @@ def qr(a, mode='reduced'):
10921092
a = _to_native_byte_order(a)
10931093
mn = min(m, n)
10941094

1095-
if m <= n:
1096-
gufunc = _umath_linalg.qr_r_raw_m
1097-
else:
1098-
gufunc = _umath_linalg.qr_r_raw_n
1099-
11001095
signature = 'D->D' if isComplexType(t) else 'd->d'
11011096
with errstate(call=_raise_linalgerror_qr, invalid='call',
11021097
over='ignore', divide='ignore', under='ignore'):
1103-
tau = gufunc(a, signature=signature)
1098+
tau = _umath_linalg.qr_r_raw(a, signature=signature)
11041099

11051100
# handle modes that don't return q
11061101
if mode == 'r':
@@ -1833,15 +1828,9 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
18331828
m, n = a.shape[-2:]
18341829
if compute_uv:
18351830
if full_matrices:
1836-
if m < n:
1837-
gufunc = _umath_linalg.svd_m_f
1838-
else:
1839-
gufunc = _umath_linalg.svd_n_f
1831+
gufunc = _umath_linalg.svd_f
18401832
else:
1841-
if m < n:
1842-
gufunc = _umath_linalg.svd_m_s
1843-
else:
1844-
gufunc = _umath_linalg.svd_n_s
1833+
gufunc = _umath_linalg.svd_s
18451834

18461835
signature = 'D->DdD' if isComplexType(t) else 'd->ddd'
18471836
with errstate(call=_raise_linalgerror_svd_nonconvergence,
@@ -1853,16 +1842,11 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
18531842
vh = vh.astype(result_t, copy=False)
18541843
return SVDResult(wrap(u), s, wrap(vh))
18551844
else:
1856-
if m < n:
1857-
gufunc = _umath_linalg.svd_m
1858-
else:
1859-
gufunc = _umath_linalg.svd_n
1860-
18611845
signature = 'D->d' if isComplexType(t) else 'd->d'
18621846
with errstate(call=_raise_linalgerror_svd_nonconvergence,
18631847
invalid='call', over='ignore', divide='ignore',
18641848
under='ignore'):
1865-
s = gufunc(a, signature=signature)
1849+
s = _umath_linalg.svd(a, signature=signature)
18661850
s = s.astype(_realType(result_t), copy=False)
18671851
return s
18681852

@@ -2570,11 +2554,6 @@ def lstsq(a, b, rcond=None):
25702554
if rcond is None:
25712555
rcond = finfo(t).eps * max(n, m)
25722556

2573-
if m <= n:
2574-
gufunc = _umath_linalg.lstsq_m
2575-
else:
2576-
gufunc = _umath_linalg.lstsq_n
2577-
25782557
signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid'
25792558
if n_rhs == 0:
25802559
# lapack can't handle n_rhs = 0 - so allocate
@@ -2583,7 +2562,8 @@ def lstsq(a, b, rcond=None):
25832562

25842563
with errstate(call=_raise_linalgerror_lstsq, invalid='call',
25852564
over='ignore', divide='ignore', under='ignore'):
2586-
x, resids, rank, s = gufunc(a, b, rcond, signature=signature)
2565+
x, resids, rank, s = _umath_linalg.lstsq(a, b, rcond,
2566+
signature=signature)
25872567
if m == 0:
25882568
x[...] = 0
25892569
if n_rhs == 0:

0 commit comments

Comments
 (0)