1212 _get_underlying_float ,
1313 val_to_int_ptr ,
1414)
15- from pytensor .link .numba .dispatch .linalg .utils import _check_scipy_linalg_matrix
15+ from pytensor .link .numba .dispatch .linalg .utils import (
16+ _check_scipy_linalg_matrix ,
17+ _copy_to_fortran_order_even_if_1d ,
18+ _trans_char_to_int ,
19+ )
1620
1721
1822@numba_njit (inline = "always" )
@@ -32,69 +36,140 @@ def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
3236 return A_banded
3337
3438
35- def _dot_banded (A : np .ndarray , x : np .ndarray , kl : int , ku : int ) -> Any :
39+ def _gbmv (
40+ alpha : np .ndarray ,
41+ A : np .ndarray ,
42+ x : np .ndarray ,
43+ kl : int ,
44+ ku : int ,
45+ beta : np .ndarray | None = None ,
46+ y : np .ndarray | None = None ,
47+ overwrite_y : bool = False ,
48+ trans : int = 1 ,
49+ ) -> Any :
3650 """
3751 Thin wrapper around gmbv. This code will only be called if njit is disabled globally
3852 (e.g. during testing)
3953 """
40- fn = linalg .get_blas_funcs ("gbmv" , (A , x ))
54+ ( fn ,) = linalg .get_blas_funcs (( "gbmv" ,) , (A , x ))
4155 m , n = A .shape
4256 A_banded = A_to_banded (A , kl = kl , ku = ku )
4357
44- return fn (m = m , n = n , kl = kl , ku = ku , alpha = 1 , a = A_banded , x = x )
45-
46-
47- @overload (_dot_banded )
48- def dot_banded_impl (
49- A : np .ndarray , x : np .ndarray , kl : int , ku : int
50- ) -> Callable [[np .ndarray , np .ndarray , int , int ], np .ndarray ]:
58+ incx = x .strides [0 ] // x .itemsize
59+ incy = y .strides [0 ] // y .itemsize if y is not None else 1
60+
61+ offx = 0 if incx >= 0 else - x .size + 1
62+ offy = 0 if incy >= 0 else - y .size + 1
63+
64+ return fn (
65+ m = m ,
66+ n = n ,
67+ kl = kl ,
68+ ku = ku ,
69+ a = A_banded ,
70+ alpha = alpha ,
71+ x = x ,
72+ incx = incx ,
73+ offx = offx ,
74+ beta = beta ,
75+ y = y ,
76+ overwrite_y = overwrite_y ,
77+ incy = incy ,
78+ offy = offy ,
79+ trans = trans ,
80+ )
81+
82+
83+ @overload (_gbmv )
84+ def gbmv_impl (
85+ alpha : np .ndarray ,
86+ A : np .ndarray ,
87+ x : np .ndarray ,
88+ kl : int ,
89+ ku : int ,
90+ beta : np .ndarray | None = None ,
91+ y : np .ndarray | None = None ,
92+ overwrite_y : bool = False ,
93+ trans : int = 1 ,
94+ ) -> Callable [
95+ [
96+ np .ndarray ,
97+ np .ndarray ,
98+ np .ndarray ,
99+ int ,
100+ int ,
101+ np .ndarray | None ,
102+ np .ndarray | None ,
103+ bool ,
104+ int ,
105+ ],
106+ np .ndarray ,
107+ ]:
51108 ensure_lapack ()
52109 ensure_blas ()
53110 _check_scipy_linalg_matrix (A , "dot_banded" )
54111 dtype = A .dtype
55112 w_type = _get_underlying_float (dtype )
56113 numba_gbmv = _BLAS ().numba_xgbmv (dtype )
57114
58- def impl (A : np .ndarray , x : np .ndarray , kl : int , ku : int ) -> np .ndarray :
115+ def impl (
116+ alpha : np .ndarray ,
117+ A : np .ndarray ,
118+ x : np .ndarray ,
119+ kl : int ,
120+ ku : int ,
121+ beta : np .ndarray | None = None ,
122+ y : np .ndarray | None = None ,
123+ overwrite_y : bool = False ,
124+ trans : int = 1 ,
125+ ) -> np .ndarray :
59126 m , n = A .shape
60127
61128 A_banded = A_to_banded (A , kl = kl , ku = ku )
62- stride = x .strides [0 ] // x .itemsize
129+ x_stride = x .strides [0 ] // x .itemsize
130+
131+ if beta is None :
132+ beta = np .zeros ((), dtype = dtype )
63133
64- TRANS = val_to_int_ptr (ord ("N" ))
134+ if y is None :
135+ y_copy = np .empty (shape = (m ,), dtype = dtype )
136+ elif overwrite_y and y .flags .f_contiguous :
137+ y_copy = y
138+ else :
139+ y_copy = _copy_to_fortran_order_even_if_1d (y )
140+
141+ y_stride = y_copy .strides [0 ] // y_copy .itemsize
142+
143+ TRANS = val_to_int_ptr (_trans_char_to_int (trans ))
65144 M = val_to_int_ptr (m )
66145 N = val_to_int_ptr (n )
67146 LDA = val_to_int_ptr (A_banded .shape [0 ])
68147
69148 KL = val_to_int_ptr (kl )
70149 KU = val_to_int_ptr (ku )
71150
72- ALPHA = np .array (1.0 , dtype = dtype )
73-
74- INCX = val_to_int_ptr (stride )
75- BETA = np .array (0.0 , dtype = dtype )
76- Y = np .empty (m , dtype = dtype )
77- INCY = val_to_int_ptr (1 )
151+ INCX = val_to_int_ptr (x_stride )
152+ INCY = val_to_int_ptr (y_stride )
78153
79154 numba_gbmv (
80155 TRANS ,
81156 M ,
82157 N ,
83158 KL ,
84159 KU ,
85- ALPHA .view (w_type ).ctypes ,
160+ alpha .view (w_type ).ctypes ,
86161 A_banded .view (w_type ).ctypes ,
87162 LDA ,
88163 # x.view().ctypes is creating a pointer to the beginning of the memory where the array is. When we have
89164 # a negative stride, we need to trick BLAS by pointing to the last element of the array.
90165 # The [-1:] slice is a workaround to make sure x remains an array (otherwise it has no .ctypes)
91- (x if stride >= 0 else x [- 1 :]).view (w_type ).ctypes ,
166+ (x if x_stride >= 0 else x [- 1 :]).view (w_type ).ctypes ,
92167 INCX ,
93- BETA .view (w_type ).ctypes ,
94- Y .view (w_type ).ctypes ,
168+ beta .view (w_type ).ctypes ,
169+ y_copy .view (w_type ).ctypes ,
95170 INCY ,
96171 )
97172
98- return Y
173+ return y_copy
99174
100175 return impl
0 commit comments