Skip to content

Commit 669b4dc

Browse files
benchmarking -- do not merge
1 parent 94d3aa3 commit 669b4dc

File tree

2 files changed: +59 −35 lines changed

pytensor/tensor/slinalg.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -738,12 +738,20 @@ class LUSolve(Op):
738738
Solve a system of linear equations given the LU factorization of the matrix.
739739
"""
740740

741-
__props__ = ("trans", "overwrite_b", "check_finite", "b_ndim")
741+
__props__ = ("trans", "overwrite_b", "check_finite", "b_ndim", "expect_pivots")
742742

743-
def __init__(self, b_ndim, trans=False, overwrite_b=False, check_finite=True):
743+
def __init__(
744+
self,
745+
b_ndim,
746+
trans=False,
747+
overwrite_b=False,
748+
check_finite=True,
749+
expect_pivots=False,
750+
):
744751
self.trans = trans
745752
self.overwrite_b = overwrite_b
746753
self.check_finite = check_finite
754+
self.expect_pivots = expect_pivots
747755

748756
assert b_ndim in (1, 2)
749757
self.b_ndim = b_ndim
@@ -789,6 +797,9 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
789797
def perform(self, node, inputs, outputs):
790798
LU, pivots, b = inputs
791799

800+
if not self.expect_pivots:
801+
raise NotImplementedError
802+
792803
outputs[0][0] = scipy_linalg.lu_solve(
793804
lu_and_piv=(LU, pivots),
794805
b=b,
@@ -807,8 +818,12 @@ def L_op(
807818
[x] = outputs
808819
[x_bar] = output_grads
809820

810-
p_inv = _pivot_to_permutation(pivots)
811-
p = pt.argsort(p_inv)
821+
if not self.expect_pivots:
822+
p_inv = _pivot_to_permutation(pivots)
823+
p = pt.argsort(p_inv)
824+
else:
825+
p = pivots
826+
812827
P = ptb.identity_like(LU)[p]
813828

814829
# We are solving PLUx = b

tests/tensor/test_slinalg.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.tensor.slinalg import (
1616
Cholesky,
1717
CholeskySolve,
18+
LUSolve,
1819
Solve,
1920
SolveBase,
2021
SolveTriangular,
@@ -703,7 +704,8 @@ def test_lu_factor(permutation_indices):
703704

704705
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
705706
@pytest.mark.parametrize("trans", [True, False])
706-
def test_lu_solve(b_shape: tuple[int], trans):
707+
@pytest.mark.parametrize("use_op", [True, False])
708+
def test_lu_solve(b_shape: tuple[int], trans, use_op):
707709
def T(x):
708710
if trans:
709711
return x.T
@@ -717,7 +719,13 @@ def T(x):
717719
b_val = rng.normal(size=b_shape).astype(config.floatX)
718720

719721
LU_and_pivots = lu_factor(A)
720-
x = lu_solve(LU_and_pivots, b, trans=trans)
722+
723+
if use_op:
724+
x = LUSolve(b_ndim=len(b_shape), trans=trans, check_finite=False)(
725+
LU_and_pivots, b
726+
)
727+
else:
728+
x = lu_solve(LU_and_pivots, b, trans=trans)
721729

722730
f = pytensor.function([A, b], x)
723731
x_pt = f(A_val.copy(), b_val.copy())
@@ -735,26 +743,6 @@ def T(x):
735743
)
736744
np.testing.assert_allclose(x_pt, x_sp)
737745

738-
# import jax
739-
# import jax.scipy as jsp
740-
#
741-
# def jax_f(A, b):
742-
# LU_and_pivots = jsp.linalg.lu_factor(A)
743-
# x = jsp.linalg.lu_solve(LU_and_pivots, b, trans=trans)
744-
# return x.sum()
745-
746-
# jax_res = jax.value_and_grad(jax_f, [0, 1])(A_val, b_val)
747-
# g = grad(x.sum(), [A, b])
748-
# fg = pytensor.function([A, b], [x.sum(), *g])
749-
750-
# for a, b in zip(fg(A_val, b_val), [jax_res[0], *jax_res[1]]):
751-
# print(a - b)
752-
753-
# LU, pivots = pt.tensor('LU', shape=(5, 5)), pt.tensor('pivots', shape=(5,), dtype='int')
754-
# x = lu_solve((LU, pivots), b)
755-
756-
# LU_val, pivots_val = scipy.linalg.lu_factor(A_val)
757-
758746
utt.verify_grad(
759747
lambda A, b: lu_solve(lu_factor(A), b, trans=trans).sum(),
760748
pt=[A_val.copy(), b_val.copy()],
@@ -776,15 +764,6 @@ def test_fn(A, b):
776764
x = lu_solve(lu_and_pivots, b)
777765
return x.sum()
778766

779-
# A = pt.tensor("A", shape=(5, 5))
780-
# b = pt.tensor("b", shape=b_shape)
781-
782-
# fg = pytensor.function([A, b], grad(test_fn(A, b), [A, b]))
783-
# fg2 = pytensor.function([A, b], grad(pt.linalg.solve(A, b).sum(), [A, b]))
784-
785-
# print(fg(A_val, b_val))
786-
# print(fg2(A_val, b_val))
787-
788767
utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
789768

790769

@@ -1065,3 +1044,33 @@ def test_block_diagonal_blockwise():
10651044
B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
10661045
result = block_diag(A, B).eval()
10671046
assert result.shape == (10, batch_size, 6, 6)
1047+
1048+
1049+
def lu_solve_1(A, b):
1050+
lu, pivots = pt.linalg.lu_factor(A)
1051+
return pt.linalg.lu_solve((lu, pivots), b)
1052+
1053+
1054+
def lu_solve_2(A, b, b_ndim=1, trans=0, check_finite=False):
1055+
lu, pivots = pt.linalg.lu_factor(A)
1056+
return LUSolve(b_ndim=1, trans=0, check_finite=False)(lu, pivots, b)
1057+
1058+
1059+
@pytest.mark.parametrize(
1060+
"op", [lu_solve_1, lu_solve_2, pt.linalg.solve], ids=["lu_1", "lu_2", "solve"]
1061+
)
1062+
@pytest.mark.parametrize("n", [500])
1063+
def test_solve_methods(op, n, benchmark):
1064+
A = pt.tensor("A", shape=(n, n))
1065+
b = pt.tensor("b", shape=(n,))
1066+
1067+
x = op(A, b)
1068+
gx = pt.grad(x.sum(), [A, b])
1069+
f = pytensor.function([A, b], [x, *gx])
1070+
1071+
A_val = np.random.normal(size=(n, n)).astype(config.floatX)
1072+
b_val = np.random.normal(size=(n,)).astype(config.floatX)
1073+
1074+
# Trigger compilation if we're a jit mode
1075+
f(A_val, b_val)
1076+
benchmark(f, A_val, b_val)

0 commit comments (Comments: 0)