3
3
import numpy as np
4
4
from numba import njit as numba_njit
5
5
from numba .core .extending import overload
6
- from numba .np .linalg import _copy_to_fortran_order , ensure_blas , ensure_lapack
6
+ from numba .np .linalg import ensure_blas , ensure_lapack
7
7
from scipy import linalg
8
8
9
9
from pytensor .link .numba .dispatch .linalg ._BLAS import _BLAS
15
15
16
16
17
17
@numba_njit (inline = "always" )
18
- def A_to_banded (A : np .ndarray , kl : int , ku : int ) -> np .ndarray :
18
+ def A_to_banded (A : np .ndarray , kl : int , ku : int , order = "C" ) -> np .ndarray :
19
19
m , n = A .shape
20
- A_banded = np .zeros ((kl + ku + 1 , n ), dtype = A .dtype )
20
+ if order == "C" :
21
+ A_banded = np .zeros ((kl + ku + 1 , n ), dtype = A .dtype )
22
+ else :
23
+ A_banded = np .zeros ((n , kl + ku + 1 ), dtype = A .dtype ).T
21
24
22
25
for i , k in enumerate (range (ku , - kl - 1 , - 1 )):
23
26
if k >= 0 :
@@ -35,7 +38,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
35
38
"""
36
39
fn = linalg .get_blas_funcs ("gbmv" , (A , x ))
37
40
m , n = A .shape
38
- A_banded = A_to_banded (A , kl = kl , ku = ku )
41
+ A_banded = A_to_banded (A , kl = kl , ku = ku , order = "C" )
39
42
40
43
return fn (m = m , n = n , kl = kl , ku = ku , alpha = 1 , a = A_banded , x = x )
41
44
@@ -54,9 +57,7 @@ def dot_banded_impl(
54
57
def impl (A : np .ndarray , x : np .ndarray , kl : int , ku : int ) -> np .ndarray :
55
58
m , n = A .shape
56
59
57
- # TODO: Can we avoid this copy?
58
- A_banded = A_to_banded (A , kl = kl , ku = ku )
59
- A_banded = _copy_to_fortran_order (A_banded )
60
+ A_banded = A_to_banded (A , kl = kl , ku = ku , order = "F" )
60
61
61
62
TRANS = val_to_int_ptr (ord ("N" ))
62
63
M = val_to_int_ptr (m )
0 commit comments