Skip to content

Commit c18f095

Browse files
Create A_banded as F-contiguous array
1 parent 7d109b9 commit c18f095

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
3838
"""
3939
fn = linalg.get_blas_funcs("gbmv", (A, x))
4040
m, n = A.shape
41-
A_banded = A_to_banded(A, kl=kl, ku=ku, order="C")
41+
A_banded = A_to_banded(A, kl=kl, ku=ku, order="F")
4242

4343
return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
4444

pytensor/tensor/slinalg.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import scipy.linalg as scipy_linalg
9-
from numpy import diag, zeros
9+
from numpy import zeros
1010
from numpy.exceptions import ComplexWarning
1111

1212
import pytensor
@@ -1697,10 +1697,13 @@ def perform(self, node, inputs, outputs_storage):
16971697
kl = self.lower_diags
16981698
ku = self.upper_diags
16991699

1700-
A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="C")
1700+
A_banded = zeros((kl + ku + 1, n), dtype=A.dtype, order="F")
17011701

17021702
for i, k in enumerate(range(ku, -kl - 1, -1)):
1703-
A_banded[i, slice(k, None) if k >= 0 else slice(None, n + k)] = diag(A, k=k)
1703+
if k >= 0:
1704+
A_banded[i, k:] = np.diag(A, k=k)
1705+
else:
1706+
A_banded[i, : n + k] = np.diag(A, k=k)
17041707

17051708
fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
17061709
outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)

0 commit comments

Comments
 (0)