Skip to content

Commit 60bfb3b

Browse files
committed
Speedup solve tridiagonal
1 parent 674962e commit 60bfb3b

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

pytensor/tensor/slinalg.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -937,15 +937,24 @@ def __init__(self, *, assume_a="gen", **kwargs):
937937

938938
def perform(self, node, inputs, outputs):
939939
a, b = inputs
940-
outputs[0][0] = scipy_linalg.solve(
941-
a=a,
942-
b=b,
943-
lower=self.lower,
944-
check_finite=self.check_finite,
945-
assume_a=self.assume_a,
946-
overwrite_a=self.overwrite_a,
947-
overwrite_b=self.overwrite_b,
948-
)
940+
if self.assume_a == "tridiagonal":
941+
[dl, d, du] = (a.diagonal(offset=o) for o in (-1, 0, 1))
942+
_gttrf, _gttrs = scipy_linalg.get_lapack_funcs(
943+
("gttrf", "gttrs"), dtype=node.outputs[0].type.dtype
944+
)
945+
dl, d, du, du2, ipiv, _ = _gttrf(dl, d, du)
946+
x, _ = _gttrs(dl, d, du, du2, ipiv, b, overwrite_b=self.overwrite_b)
947+
outputs[0][0] = x
948+
else:
949+
outputs[0][0] = scipy_linalg.solve(
950+
a=a,
951+
b=b,
952+
lower=self.lower,
953+
check_finite=self.check_finite,
954+
assume_a=self.assume_a,
955+
overwrite_a=self.overwrite_a,
956+
overwrite_b=self.overwrite_b,
957+
)
949958

950959
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
951960
if not allowed_inplace_inputs:

0 commit comments

Comments
 (0)