Skip to content

Commit 54a2c6e

Browse files
Respect overwrite_a and overwrite_b arguments
1 parent 86dd9cb commit 54a2c6e

File tree

3 files changed

+115
-37
lines changed

3 files changed

+115
-37
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False):
107107
_solve_check_input_shapes(A, B)
108108

109109
B_is_1d = B.ndim == 1
110+
110111
if B_is_1d:
111112
B_copy = np.asfortranarray(np.expand_dims(B, -1))
112113
else:
@@ -387,7 +388,6 @@ def xgecon_impl(A, A_norm, norm):
387388

388389
def impl(A, A_norm, norm):
389390
_N = np.int32(A.shape[-1])
390-
A_copy = _copy_to_fortran_order(A)
391391

392392
N = val_to_int_ptr(_N)
393393
LDA = val_to_int_ptr(_N)
@@ -401,7 +401,7 @@ def impl(A, A_norm, norm):
401401
numba_gecon(
402402
NORM,
403403
N,
404-
A_copy.view(w_type).ctypes,
404+
A.view(w_type).ctypes,
405405
LDA,
406406
A_NORM.view(w_type).ctypes,
407407
RCOND.view(w_type).ctypes,
@@ -425,16 +425,20 @@ def _getrf():
425425

426426

427427
@overload(_getrf)
428-
def getrf_impl(A):
428+
def getrf_impl(A, overwrite_a=False):
429429
ensure_lapack()
430430
_check_scipy_linalg_matrix(A, "getrf")
431431
dtype = A.dtype
432432
w_type = _get_underlying_float(dtype)
433433
numba_getrf = _LAPACK().numba_xgetrf(dtype)
434434

435-
def impl(A):
435+
def impl(A, overwrite_a=False):
436436
_M, _N = np.int32(A.shape[-2:])
437-
A_copy = _copy_to_fortran_order(A)
437+
438+
if not overwrite_a:
439+
A_copy = _copy_to_fortran_order(A)
440+
else:
441+
A_copy = A
438442

439443
M = val_to_int_ptr(_M)
440444
N = val_to_int_ptr(_N)
@@ -459,23 +463,27 @@ def _getrs():
459463

460464

461465
@overload(_getrs)
462-
def getrs_impl(LU, B, IPIV, trans=0):
466+
def getrs_impl(LU, B, IPIV, trans=0, overwrite_b=False):
463467
ensure_lapack()
464468
_check_scipy_linalg_matrix(LU, "getrs")
465469
_check_scipy_linalg_matrix(B, "getrs")
466470
dtype = LU.dtype
467471
w_type = _get_underlying_float(dtype)
468472
numba_getrs = _LAPACK().numba_xgetrs(dtype)
469473

470-
def impl(LU, B, IPIV, trans=0):
474+
def impl(LU, B, IPIV, trans=0, overwrite_b=False):
471475
_N = np.int32(LU.shape[-1])
472476
_solve_check_input_shapes(LU, B)
473477

474478
B_is_1d = B.ndim == 1
475-
if B_is_1d:
476-
B_copy = np.asfortranarray(np.expand_dims(B, -1))
477-
else:
479+
480+
if not overwrite_b:
478481
B_copy = _copy_to_fortran_order(B)
482+
else:
483+
B_copy = B
484+
if B_is_1d:
485+
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
486+
479487
B_NDIM = 1 if B_is_1d else int(B.shape[-1])
480488

481489
TRANS = val_to_int_ptr(_trans_char_to_int(trans))
@@ -591,13 +599,21 @@ def sysv_impl(A, B, lower=False, overwrite_a=False, overwrite_b=False):
591599
def impl(A, B, lower=False, overwrite_a=False, overwrite_b=False):
592600
_LDA, _N = np.int32(A.shape[-2:])
593601
_solve_check_input_shapes(A, B)
594-
A_copy = _copy_to_fortran_order(A)
595602

596-
B_is_1d = B.ndim == 1
597-
if B_is_1d:
598-
B_copy = np.asfortranarray(np.expand_dims(B, -1))
603+
if not overwrite_a:
604+
A_copy = _copy_to_fortran_order(A)
599605
else:
606+
A_copy = A
607+
608+
B_is_1d = B.ndim == 1
609+
610+
if not overwrite_b:
600611
B_copy = _copy_to_fortran_order(B)
612+
else:
613+
B_copy = B
614+
if B_is_1d:
615+
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
616+
601617
B_NDIM = 1 if B_is_1d else int(B.shape[-1])
602618

603619
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
@@ -790,13 +806,22 @@ def impl(
790806
_solve_check_input_shapes(A, B)
791807

792808
_N = np.int32(A.shape[-1])
793-
A_copy = _copy_to_fortran_order(A)
794809

795-
B_is_1d = B.ndim == 1
796-
if B_is_1d:
797-
B_copy = np.asfortranarray(np.expand_dims(B, -1))
810+
if not overwrite_a:
811+
A_copy = _copy_to_fortran_order(A)
798812
else:
813+
A_copy = A
814+
815+
B_is_1d = B.ndim == 1
816+
817+
if not overwrite_b:
799818
B_copy = _copy_to_fortran_order(B)
819+
else:
820+
B_copy = B
821+
if B_is_1d:
822+
B_copy = np.asfortranarray(np.expand_dims(B_copy, -1))
823+
824+
B_NDIM = 1 if B_is_1d else int(B.shape[-1])
800825

801826
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
802827
B_NDIM = 1 if B_is_1d else int(B.shape[-1])
@@ -889,6 +914,11 @@ def numba_funcify_Solve(op, node, **kwargs):
889914
solve_fn = _solve_gen
890915
elif assume_a == "sym":
891916
solve_fn = _solve_symmetric
917+
elif assume_a == "her":
918+
raise NotImplementedError(
919+
'Use assume_a = "sym" for symmetric real matrices. If you need complex support, '
920+
"please open an issue on github."
921+
)
892922
elif assume_a == "pos":
893923
solve_fn = _solve_psd
894924
else:

pytensor/tensor/slinalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def solve_triangular(
436436
trans: int | str = 0,
437437
lower: bool = False,
438438
unit_diagonal: bool = False,
439+
overwrite_b: bool = False,
439440
check_finite: bool = True,
440441
b_ndim: int | None = None,
441442
) -> TensorVariable:
@@ -461,6 +462,8 @@ def solve_triangular(
461462
Whether to check that the input matrices contain only finite numbers.
462463
Disabling may give a performance gain, but may result in problems
463464
(crashes, non-termination) if the inputs do contain infinities or NaNs.
465+
overwrite_b: bool, optional
466+
If True, memory allocated to input B will be re-used for the output. Default is False.
464467
b_ndim : int
465468
Whether the core case of b is a vector (1) or matrix (2).
466469
This will influence how batched dimensions are interpreted.
@@ -472,6 +475,7 @@ def solve_triangular(
472475
trans=trans,
473476
unit_diagonal=unit_diagonal,
474477
check_finite=check_finite,
478+
overwrite_b=overwrite_b,
475479
b_ndim=b_ndim,
476480
)
477481
)(a, b)
@@ -537,6 +541,8 @@ def solve(
537541
lower=False,
538542
check_finite=True,
539543
transposed=False,
544+
overwrite_a=False,
545+
overwrite_b=False,
540546
b_ndim: int | None = None,
541547
):
542548
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -574,6 +580,10 @@ def solve(
574580
(crashes, non-termination) if the inputs do contain infinities or NaNs.
575581
assume_a : str, optional
576582
Valid entries are explained above.
583+
overwrite_a: bool, optional
584+
If True, use A as a work space to avoid allocating new memory. Default is False.
585+
overwrite_b: bool, optional
586+
If True, use B to store the result. Otherwise, allocate new memory. Default is False.
577587
transposed: bool, optional
578588
If True, solve ``A.T @ x = b``
579589
b_ndim : int
@@ -588,6 +598,8 @@ def solve(
588598
assume_a=assume_a,
589599
b_ndim=b_ndim,
590600
transposed=transposed,
601+
overwrite_a=overwrite_a,
602+
overwrite_b=overwrite_b,
591603
)
592604
)(a, b)
593605

tests/link/numba/test_slinalg.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import numpy as np
44
import pytest
5+
from numpy.testing import assert_allclose
56

67
import pytensor
78
import pytensor.tensor as pt
89
from pytensor.graph import FunctionGraph
910
from tests.link.numba.test_basic import compare_numba_and_py
10-
from tests.unittest_tools import assert_allclose
1111

1212

1313
numba = pytest.importorskip("numba")
@@ -42,7 +42,10 @@ def transpose_func(x, trans):
4242
@pytest.mark.filterwarnings(
4343
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
4444
)
45-
def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
45+
@pytest.mark.parametrize("overwrite_b", [True, False])
46+
def test_solve_triangular(
47+
b_func, b_size, lower, trans, unit_diag, complex, overwrite_b
48+
):
4649
if complex:
4750
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous,
4851
# why?
@@ -55,7 +58,7 @@ def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
5558
b = b_func("b", dtype=dtype)
5659

5760
X = pt.linalg.solve_triangular(
58-
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag
61+
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag, overwrite_b=overwrite_b
5962
)
6063
f = pytensor.function([A, b], X, mode="NUMBA")
6164

@@ -84,6 +87,9 @@ def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex):
8487
transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL
8588
)
8689

90+
if overwrite_b:
91+
assert_allclose(X_np, b)
92+
8793

8894
@pytest.mark.parametrize("value", [np.nan, np.inf])
8995
@pytest.mark.filterwarnings(
@@ -235,31 +241,42 @@ def gecon(x, norm):
235241
np.testing.assert_allclose(rcond, rcond2)
236242

237243

238-
def test_getrf():
244+
@pytest.mark.parametrize("overwrite_a", [True, False])
245+
def test_getrf(overwrite_a):
239246
from scipy.linalg import lu_factor
240247

241248
from pytensor.link.numba.dispatch.slinalg import _getrf
242249

243250
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
244251

245252
@numba.njit()
246-
def getrf(x):
247-
return _getrf(x)
253+
def getrf(x, overwrite_a):
254+
return _getrf(x, overwrite_a=overwrite_a)
248255

249256
x = np.random.normal(size=(5, 5)).astype(floatX)
250-
LU, IPIV, info = getrf(x)
251-
lu, ipiv = lu_factor(x)
257+
x = np.asfortranarray(
258+
x
259+
) # x needs to be fortran-contiguous going into getrf for the overwrite option to work
260+
261+
lu, ipiv = lu_factor(x, overwrite_a=False)
262+
LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)
252263

253264
assert info == 0
254265
assert_allclose(LU, lu)
255266

267+
if overwrite_a:
268+
assert_allclose(x, LU)
269+
256270
# TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing
257271
# this, though.
258272
assert_allclose(IPIV - 1, ipiv)
259273

260274

261275
@pytest.mark.parametrize("trans", [0, 1])
262-
def test_getrs(trans):
276+
@pytest.mark.parametrize("overwrite_a", [True, False])
277+
@pytest.mark.parametrize("overwrite_b", [True, False])
278+
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
279+
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
263280
from scipy.linalg import lu_factor
264281
from scipy.linalg import lu_solve as sp_lu_solve
265282

@@ -268,19 +285,29 @@ def test_getrs(trans):
268285
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
269286

270287
@numba.njit()
271-
def lu_solve(a, b, trans):
272-
lu, ipiv, info = _getrf(a)
273-
x, info = _getrs(lu, b, ipiv, trans)
274-
return x, info
288+
def lu_solve(a, b, trans, overwrite_a, overwrite_b):
289+
lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a)
290+
x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b)
291+
return x, lu, info
275292

276293
a = np.random.normal(size=(5, 5)).astype(floatX)
277-
b = np.random.normal(size=(5, 3)).astype(floatX)
294+
b = np.random.normal(size=b_shape).astype(floatX)
278295

279-
lu_and_piv = lu_factor(a)
296+
# inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work
297+
a = np.asfortranarray(a)
298+
b = np.asfortranarray(b)
280299

281-
x_sp = sp_lu_solve(lu_and_piv, b, trans)
282-
x, info = lu_solve(a, b, trans)
300+
lu_and_piv = lu_factor(a, overwrite_a=False)
301+
x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)
302+
303+
x, lu, info = lu_solve(
304+
a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b
305+
)
283306
assert info == 0
307+
if overwrite_a:
308+
assert_allclose(a, lu)
309+
if overwrite_b:
310+
assert_allclose(b, x)
284311

285312
assert_allclose(x, x_sp)
286313

@@ -295,12 +322,21 @@ def lu_solve(a, b, trans):
295322
@pytest.mark.filterwarnings(
296323
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
297324
)
298-
def test_solve(b_func, b_size, assume_a, transposed):
325+
@pytest.mark.parametrize("overwrite_a", [True, False])
326+
@pytest.mark.parametrize("overwrite_b", [True, False])
327+
def test_solve(b_func, b_size, assume_a, transposed, overwrite_a, overwrite_b):
299328
A = pt.matrix("A", dtype=floatX)
300329
b = b_func("b", dtype=floatX)
301330

302331
X = pt.linalg.solve(
303-
A, b, lower=False, assume_a=assume_a, transposed=transposed, b_ndim=len(b_size)
332+
A,
333+
b,
334+
lower=False,
335+
assume_a=assume_a,
336+
overwrite_a=overwrite_a,
337+
overwrite_b=overwrite_b,
338+
transposed=transposed,
339+
b_ndim=len(b_size),
304340
)
305341
f = pytensor.function([A, b], X, mode="NUMBA")
306342

0 commit comments

Comments
 (0)