6
6
7
7
import numpy as np
8
8
import scipy .linalg as scipy_linalg
9
+ from numpy import diag , zeros
9
10
from numpy .exceptions import ComplexWarning
10
11
11
12
import pytensor
@@ -1669,39 +1670,6 @@ def block_diag(*matrices: TensorVariable):
1669
1670
return _block_diagonal_matrix (* matrices )
1670
1671
1671
1672
1672
- def _to_banded_form (A , kl , ku ):
1673
- """
1674
- Convert a full matrix A to LAPACK banded form for gbmv.
1675
-
1676
- Parameters
1677
- ----------
1678
- A: np.ndarray
1679
- (m, n) banded matrix with nonzero values on the diagonals
1680
- kl: int
1681
- Number of nonzero lower diagonals of A
1682
- ku: int
1683
- Number of nonzero upper diagonals of A
1684
-
1685
- Returns
1686
- -------
1687
- ab: np.ndarray
1688
- (kl + ku + 1, n) banded matrix suitable for LAPACK
1689
- """
1690
- A = np .asarray (A )
1691
- m , n = A .shape
1692
- ab = np .zeros ((kl + ku + 1 , n ), dtype = A .dtype , order = "C" )
1693
-
1694
- for i , k in enumerate (range (ku , - kl - 1 , - 1 )):
1695
- col_slice = slice (k , None ) if k >= 0 else slice (None , n + k )
1696
- ab [i , col_slice ] = np .diag (A , k = k )
1697
-
1698
- return ab
1699
-
1700
-
1701
- _dgbmv = scipy_linalg .get_blas_funcs ("gbmv" , dtype = "float64" )
1702
- _sgbmv = scipy_linalg .get_blas_funcs ("gbmv" , dtype = "float32" )
1703
-
1704
-
1705
1673
class BandedDot (Op ):
1706
1674
__props__ = ("lower_diags" , "upper_diags" )
1707
1675
gufunc_signature = "(m,n),(n)->(m)"
@@ -1725,15 +1693,17 @@ def infer_shape(self, fgraph, nodes, shapes):
1725
1693
def perform (self , node , inputs , outputs_storage ):
1726
1694
A , b = inputs
1727
1695
m , n = A .shape
1728
- alpha = 1
1729
1696
1730
1697
kl = self .lower_diags
1731
1698
ku = self .upper_diags
1732
1699
1733
- A_banded = _to_banded_form (A , kl , ku )
1700
+ A_banded = zeros ((kl + ku + 1 , n ), dtype = A .dtype , order = "C" )
1701
+
1702
+ for i , k in enumerate (range (ku , - kl - 1 , - 1 )):
1703
+ A_banded [i , slice (k , None ) if k >= 0 else slice (None , n + k )] = diag (A , k = k )
1734
1704
1735
- fn = _dgbmv if A . dtype == "float64" else _sgbmv
1736
- outputs_storage [0 ][0 ] = fn (m , n , kl , ku , alpha , A_banded , b )
1705
+ fn = scipy_linalg . get_blas_funcs ( "gbmv" , dtype = A . dtype )
1706
+ outputs_storage [0 ][0 ] = fn (m = m , n = n , kl = kl , ku = ku , alpha = 1 , a = A_banded , x = b )
1737
1707
1738
1708
1739
1709
def banded_dot (A : TensorLike , b : TensorLike , lower_diags : int , upper_diags : int ):
0 commit comments