Skip to content

Commit fc8d4e7

Browse files
authored
Merge pull request scipy#23547 from ev-br/solve_sy_he
`linalg.solve`: move "sym"/"her" solvers to C
2 parents 7bbabe5 + 37129b3 commit fc8d4e7

File tree

7 files changed

+542
-57
lines changed

7 files changed

+542
-57
lines changed

benchmarks/benchmarks/linalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,36 @@ def time_svd(self, size, contig, module):
102102
)
103103

104104

105+
class BatchedSolveBench(Benchmark):
106+
params = [
107+
[(100, 10, 10), (100, 20, 20), (100, 100)],
108+
["gen", "pos", "sym"],
109+
["scipy", "numpy"]
110+
]
111+
param_names = ["shape", "structure" ,"module"]
112+
113+
def setup(self, shape, structure, module):
114+
a = random(shape)
115+
# larger diagonal ensures non-singularity:
116+
for i in range(shape[-1]):
117+
a[..., i, i] = 10*(.1+a[..., i, i])
118+
119+
if structure == "pos":
120+
self.a = a @ a.mT
121+
elif structure == "sym":
122+
self.a = a + a.mT
123+
else:
124+
self.a = a
125+
126+
self.b = random([a.shape[-1]])
127+
128+
def time_solve(self, shape, structure, module):
129+
if module == 'numpy':
130+
nl.solve(self.a, self.b)
131+
else:
132+
sl.solve(self.a, self.b, assume_a=structure)
133+
134+
105135
class Norm(Benchmark):
106136
params = [
107137
[(20, 20), (100, 100), (1000, 1000), (20, 1000), (1000, 20)],

scipy/linalg/_basic.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def solve(a, b, lower=False, overwrite_a=False,
209209
[ 5. , -4.5]])
210210
"""
211211
if assume_a in [
212-
'sym', 'her', 'symmetric', 'hermitian', 'diagonal', 'tridiagonal', 'banded'
212+
'diagonal', 'tridiagonal', 'banded'
213213
]:
214214
# TODO: handle these structures in this function
215215
return solve0(
@@ -225,6 +225,8 @@ def solve(a, b, lower=False, overwrite_a=False,
225225
'upper triangular': 21,
226226
'lower triangular': 22,
227227
'pos' : 101, 'positive definite': 101,
228+
'sym' : 201, 'symmetric': 201,
229+
'her' : 211, 'hermitian': 211,
228230
}.get(assume_a, 'unknown')
229231
if structure == 'unknown':
230232
raise ValueError(f'{assume_a} is not a recognized matrix structure')
@@ -1349,10 +1351,12 @@ def inv(a, overwrite_a=False, check_finite=True, assume_a=None, lower=False):
13491351
upper triangular 'upper triangular'
13501352
lower triangular 'lower triangular'
13511353
symmetric positive definite 'pos'
1354+
symmetric 'sym'
1355+
Hermitian 'her'
13521356
============================= ================================
13531357
1354-
For the 'pos' option, only the triangle of the input matrix specified in
1355-
the `lower` argument is used, and the other triangle is not referenced.
1358+
For the 'pos', 'sym' and 'her' options, only the specified triangle of the input
1359+
matrix is used, and the other triangle is not referenced.
13561360
13571361
Array argument(s) of this function may have additional
13581362
"batch" dimensions prepended to the core shape. In this case, the array is treated
@@ -1432,14 +1436,16 @@ def inv(a, overwrite_a=False, check_finite=True, assume_a=None, lower=False):
14321436
overwrite_a = True
14331437
a1 = a1.copy()
14341438

1435-
# keep the numbers in sync with C
1439+
# keep the numbers in sync with C at `linalg/src/_common_array_utils.hh`
14361440
structure = {
14371441
None: -1,
14381442
'general': 0,
14391443
# 'diagonal': 11,
14401444
'upper triangular': 21,
14411445
'lower triangular': 22,
14421446
'pos' : 101,
1447+
'sym' : 201,
1448+
'her' : 211,
14431449
}[assume_a]
14441450

14451451
# a1 is well behaved, invert it.

scipy/linalg/src/_batched_linalg_module.cc

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "_common_array_utils.hh"
77

88

9-
#define PYERR(errobj,message) {PyErr_SetString(errobj,message); return NULL;}
109
static PyObject* _linalg_inv_error;
1110

1211

@@ -45,7 +44,8 @@ _linalg_inv(PyObject* Py_UNUSED(dummy), PyObject* args) {
4544
npy_intp* shape = PyArray_SHAPE(ap_Am); // Array shape
4645
npy_intp n = shape[ndim - 1]; // Slice size
4746
if (n != shape[ndim - 2]) {
48-
PYERR(PyExc_ValueError, "Last two dimensions of the input must be the same.")
47+
PyErr_SetString(PyExc_ValueError, "Last two dimensions of the input must be the same.");
48+
return NULL;
4949
}
5050

5151
overwrite_a = 0; // TODO: enable it
@@ -79,13 +79,15 @@ _linalg_inv(PyObject* Py_UNUSED(dummy), PyObject* args) {
7979
info = _inverse<npy_complex128>(ap_Am, (npy_complex128 *)buf, structure, lower, overwrite_a, vec_status);
8080
break;
8181
default:
82-
PYERR(PyExc_RuntimeError, "Unknown array type.")
82+
PyErr_SetString(PyExc_RuntimeError, "Unknown array type.");
83+
return NULL;
8384
}
8485

8586
if (info < 0) {
8687
// Either OOM or internal LAPACK error.
8788
Py_DECREF(ap_Ainv);
88-
PYERR(PyExc_RuntimeError, "Memory error in scipy.linalg.inv.")
89+
PyErr_SetString(PyExc_RuntimeError, "Memory error in scipy.linalg.inv.");
90+
return NULL;
8991
}
9092
PyObject *ret_lst = convert_vec_status(vec_status);
9193

@@ -127,7 +129,8 @@ _linalg_solve(PyObject* Py_UNUSED(dummy), PyObject* args) {
127129
int ndim = PyArray_NDIM(ap_Am);
128130
npy_intp* shape = PyArray_SHAPE(ap_Am);
129131
if ((ndim < 2) || (shape[ndim - 1] != shape[ndim - 2])) {
130-
PYERR(PyExc_ValueError, "Last two dimensions of `a` must be the same.")
132+
PyErr_SetString(PyExc_ValueError, "Last two dimensions of `a` must be the same.");
133+
return NULL;
131134
}
132135

133136
// At the python call site,
@@ -171,13 +174,15 @@ _linalg_solve(PyObject* Py_UNUSED(dummy), PyObject* args) {
171174
info = _solve<npy_complex128>(ap_Am, ap_b, (npy_complex128 *)buf, structure, lower, transposed, overwrite_a, vec_status);
172175
break;
173176
default:
174-
PYERR(PyExc_RuntimeError, "Unknown array type.")
177+
PyErr_SetString(PyExc_RuntimeError, "Unknown array type.");
178+
return NULL;
175179
}
176180

177181
if (info < 0) {
178-
// Either OOM or internal LAPACK error.
182+
// Either OOM error or requiested lwork too large.
179183
Py_DECREF(ap_x);
180-
PYERR(PyExc_RuntimeError, "Memory error in scipy.linalg.solve.")
184+
PyErr_SetString(PyExc_MemoryError, "Memory error in scipy.linalg.solve.");
185+
return NULL;
181186
}
182187
PyObject *ret_lst = convert_vec_status(vec_status);
183188

0 commit comments

Comments
 (0)