Skip to content

Commit 161e172

Browse files
micro-optimizations
1 parent 0ce2cae commit 161e172

File tree

1 file changed

+7
-37
lines changed

1 file changed

+7
-37
lines changed

pytensor/tensor/slinalg.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import scipy.linalg as scipy_linalg
9+
from numpy import diag, zeros
910
from numpy.exceptions import ComplexWarning
1011

1112
import pytensor
@@ -1669,39 +1670,6 @@ def block_diag(*matrices: TensorVariable):
16691670
return _block_diagonal_matrix(*matrices)
16701671

16711672

1672-
def _to_banded_form(A, kl, ku):
1673-
"""
1674-
Convert a full matrix A to LAPACK banded form for gbmv.
1675-
1676-
Parameters
1677-
----------
1678-
A: np.ndarray
1679-
(m, n) banded matrix with nonzero values on the diagonals
1680-
kl: int
1681-
Number of nonzero lower diagonals of A
1682-
ku: int
1683-
Number of nonzero upper diagonals of A
1684-
1685-
Returns
1686-
-------
1687-
ab: np.ndarray
1688-
(kl + ku + 1, n) banded matrix suitable for LAPACK
1689-
"""
1690-
A = np.asarray(A)
1691-
m, n = A.shape
1692-
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")
1693-
1694-
for i, k in enumerate(range(ku, -kl - 1, -1)):
1695-
col_slice = slice(k, None) if k >= 0 else slice(None, n + k)
1696-
ab[i, col_slice] = np.diag(A, k=k)
1697-
1698-
return ab
1699-
1700-
1701-
_dgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
1702-
_sgbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")
1703-
1704-
17051673
class BandedDot(Op):
17061674
__props__ = ("lower_diags", "upper_diags")
17071675
gufunc_signature = "(m,n),(n)->(m)"
@@ -1725,15 +1693,17 @@ def infer_shape(self, fgraph, nodes, shapes):
17251693
def perform(self, node, inputs, outputs_storage):
17261694
A, b = inputs
17271695
m, n = A.shape
1728-
alpha = 1
17291696

17301697
kl = self.lower_diags
17311698
ku = self.upper_diags
17321699

1733-
A_banded = _to_banded_form(A, kl, ku)
1700+
A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="C")
1701+
1702+
for i, k in enumerate(range(ku, -kl - 1, -1)):
1703+
A_banded[i, slice(k, None) if k >= 0 else slice(None, n + k)] = diag(A, k=k)
17341704

1735-
fn = _dgbmv if A.dtype == "float64" else _sgbmv
1736-
outputs_storage[0][0] = fn(m, n, kl, ku, alpha, A_banded, b)
1705+
fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
1706+
outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=b)
17371707

17381708

17391709
def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int):

0 commit comments

Comments
 (0)