Commit b16189e
Add numba dispatch for banded_dot
1 parent 1ddd529 commit b16189e

File tree: 6 files changed, +199 -1 lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
             message=(
                 "(\x1b\\[1m)*"  # ansi escape code for bold text
                 "Cannot cache compiled function "
-                '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
+                '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor|banded_dot)" '
                 "as it uses dynamic globals"
             ),
             category=NumbaWarning,
pytensor/link/numba/dispatch/linalg/_BLAS.py

Lines changed: 64 additions & 0 deletions

import ctypes

from numba.core.extending import get_cython_function_address
from numba.np.linalg import ensure_blas, ensure_lapack, get_blas_kind

from pytensor.link.numba.dispatch.linalg._LAPACK import (
    _get_float_pointer_for_dtype,
    _ptr_int,
)


def _get_blas_ptr_and_ptr_type(dtype, name):
    d = get_blas_kind(dtype)
    func_name = f"{d}{name}"
    float_pointer = _get_float_pointer_for_dtype(d)
    lapack_ptr = get_cython_function_address("scipy.linalg.cython_blas", func_name)

    return lapack_ptr, float_pointer


class _BLAS:
    """
    Functions to return type signatures for wrapped BLAS functions.

    Here we are specifically concerned with BLAS functions exposed by scipy, and not used by numpy.

    Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
    """

    def __init__(self):
        ensure_lapack()
        ensure_blas()

    @classmethod
    def numba_xgbmv(cls, dtype):
        """
        xGBMV performs one of the following matrix operations:

            y = alpha * A @ x + beta * y, or y = alpha * A.T @ x + beta * y

        where alpha and beta are scalars, x and y are vectors, and A is a band matrix with kl sub-diagonals
        and ku super-diagonals.
        """
        blas_ptr, float_pointer = _get_blas_ptr_and_ptr_type(dtype, "gbmv")

        functype = ctypes.CFUNCTYPE(
            None,
            _ptr_int,  # TRANS
            _ptr_int,  # M
            _ptr_int,  # N
            _ptr_int,  # KL
            _ptr_int,  # KU
            float_pointer,  # ALPHA
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # X
            _ptr_int,  # INCX
            float_pointer,  # BETA
            float_pointer,  # Y
            _ptr_int,  # INCY
        )

        return functype(blas_ptr)
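[Editor note] The function pointer comes from the Cython capsules SciPy ships in scipy.linalg.cython_blas, the same mechanism numba uses for its own LAPACK bindings. A hedged aside (not part of this commit) showing the lookup interactively for the float64 variant:

from numba.core.extending import get_cython_function_address

# "d" prefix = float64; "s" would be float32 (see numba's get_blas_kind)
addr = get_cython_function_address("scipy.linalg.cython_blas", "dgbmv")
print(hex(addr))  # raw address of the compiled dgbmv routine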

pytensor/link/numba/dispatch/linalg/dot/__init__.py

Whitespace-only changes.
pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 93 additions & 0 deletions

from collections.abc import Callable

import numpy as np
from numba import njit as numba_njit
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_blas, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS
from pytensor.link.numba.dispatch.linalg._LAPACK import (
    _get_underlying_float,
    val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix


@numba_njit(inline="always")
def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
    m, n = A.shape
    A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype)

    for i, k in enumerate(range(ku, -kl - 1, -1)):
        if k >= 0:
            A_banded[i, k:] = np.diag(A, k=k)
        else:
            A_banded[i, : n + k] = np.diag(A, k=k)

    return A_banded
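[Editor note] To make the layout concrete: A_to_banded stores diagonal k of A in row ku - k of the banded array, right-aligned for super-diagonals and left-aligned for sub-diagonals, which is the band storage ?gbmv expects. A hedged sketch in plain NumPy/SciPy (not part of this commit; scipy.linalg.blas.dgbmv is used here only to check the layout):

import numpy as np
from scipy.linalg import blas

# 4x4 tridiagonal matrix (kl = ku = 1)
A = np.array([[1.0, 5.0, 0.0, 0.0],
              [9.0, 2.0, 6.0, 0.0],
              [0.0, 10.0, 3.0, 7.0],
              [0.0, 0.0, 11.0, 4.0]])
x = np.array([1.0, 2.0, 3.0, 4.0])

# Same layout A_to_banded builds: diagonal k of A lives in row ku - k
A_banded = np.array([[0.0, 5.0, 6.0, 7.0],     # super-diagonal, right-aligned
                     [1.0, 2.0, 3.0, 4.0],     # main diagonal
                     [9.0, 10.0, 11.0, 0.0]])  # sub-diagonal, left-aligned

y = blas.dgbmv(m=4, n=4, kl=1, ku=1, alpha=1.0, a=A_banded, x=x)
np.testing.assert_allclose(y, A @ x)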
def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
    """
    Thin wrapper around gbmv. This code will only be called if njit is disabled globally
    (e.g. during testing).
    """
    fn = linalg.get_blas_funcs("gbmv", (A, x))
    m, n = A.shape
    A_banded = A_to_banded(A, kl=kl, ku=ku)

    return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
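[Editor note] A hedged aside (not part of this commit): get_blas_funcs resolves the type prefix from its array arguments, so this fallback transparently selects sgbmv or dgbmv to match the input dtype:

import numpy as np
from scipy.linalg import get_blas_funcs

fn = get_blas_funcs("gbmv", (np.empty((3, 3)), np.empty(3)))
print(fn.typecode)  # 'd' for float64 inputs, 's' for float32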
@overload(_dot_banded)
def dot_banded_impl(
    A: np.ndarray, x: np.ndarray, kl: int, ku: int
) -> Callable[[np.ndarray, np.ndarray, int, int], np.ndarray]:
    ensure_lapack()
    ensure_blas()
    _check_scipy_linalg_matrix(A, "dot_banded")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_gbmv = _BLAS().numba_xgbmv(dtype)

    def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
        m, n = A.shape

        # TODO: Can we avoid this copy?
        A_banded = A_to_banded(A, kl=kl, ku=ku)
        A_banded = _copy_to_fortran_order(A_banded)

        TRANS = val_to_int_ptr(ord("N"))
        M = val_to_int_ptr(m)
        N = val_to_int_ptr(n)
        LDA = val_to_int_ptr(A_banded.shape[0])

        KL = val_to_int_ptr(kl)
        KU = val_to_int_ptr(ku)

        ALPHA = np.array(1.0, dtype=dtype)
        INCX = val_to_int_ptr(1)
        BETA = np.array(0.0, dtype=dtype)
        Y = np.empty(m, dtype=dtype)
        INCY = val_to_int_ptr(1)

        numba_gbmv(
            TRANS,
            M,
            N,
            KL,
            KU,
            ALPHA.view(w_type).ctypes,
            A_banded.view(w_type).ctypes,
            LDA,
            x.view(w_type).ctypes,
            INCX,
            BETA.view(w_type).ctypes,
            Y.view(w_type).ctypes,
            INCY,
        )

        return Y

    return impl
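[Editor note] Because the implementation is registered with @overload, _dot_banded is callable from any njit-compiled function, while plain-Python calls hit the SciPy fallback above. A hedged usage sketch (not part of the commit; assumes A really has at most kl sub- and ku super-diagonals):

from numba import njit

@njit
def tridiag_matvec(A, x):
    # kl/ku must match A's true bandwidth; entries outside the band are dropped
    return _dot_banded(A, x, kl=1, ku=1)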

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 18 additions & 0 deletions
@@ -11,6 +11,7 @@
     _pivot_to_permutation,
 )
 from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
+from pytensor.link.numba.dispatch.linalg.dot.banded import _dot_banded
 from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
 from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
 from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
@@ -19,6 +20,7 @@
 from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
 from pytensor.tensor.slinalg import (
     LU,
+    BandedDot,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
@@ -311,3 +313,19 @@ def cho_solve(c, b):
     )

     return cho_solve
+
+
+@numba_funcify.register(BandedDot)
+def numba_funcify_BandedDot(op, node, **kwargs):
+    kl = op.lower_diags
+    ku = op.upper_diags
+    dtype = node.inputs[0].dtype
+
+    if dtype in complex_dtypes:
+        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
+
+    @numba_njit
+    def banded_dot(A, x):
+        return _dot_banded(A, x, kl=kl, ku=ku)
+
+    return banded_dot
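[Editor note] End to end, the new dispatch means a graph containing BandedDot can now be compiled with the numba backend. A hedged sketch (not part of this commit; assumes PyTensor's "NUMBA" compile mode string):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import banded_dot

A = pt.matrix("A")
x = pt.vector("x")
y = banded_dot(A, x, lower_diags=1, upper_diags=1)

# Compile via the numba linker; the graph's BandedDot node lowers to gbmv
fn = pytensor.function([A, x], y, mode="NUMBA")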

tests/link/numba/test_slinalg.py

Lines changed: 23 additions & 0 deletions
@@ -15,8 +15,10 @@
     LUFactor,
     Solve,
     SolveTriangular,
+    banded_dot,
 )
 from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
+from tests.tensor.test_slinalg import _make_banded_A


 pytestmark = pytest.mark.filterwarnings("error")
@@ -720,3 +722,24 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo

     # Can never destroy non-contiguous inputs
     np.testing.assert_allclose(b_val_not_contig, b_val)
+
+
+def test_banded_dot():
+    rng = np.random.default_rng()
+
+    A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)
+    x_val = rng.normal(size=(10,)).astype(config.floatX)
+
+    A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
+    x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)
+
+    output = banded_dot(A, x, upper_diags=1, lower_diags=1)
+
+    compare_numba_and_py(
+        [A, x],
+        output,
+        test_inputs=[A_val, x_val],
+        inplace=True,
+        numba_mode=numba_inplace_mode,
+        eval_obj_mode=False,
+    )
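[Editor note] _make_banded_A is imported from tests/tensor/test_slinalg.py and is not shown in this diff. Judging by the call site, a helper of that shape would zero everything outside the (kl, ku) band of a dense matrix; a hypothetical stand-in (name and implementation assumed, not the real helper):

import numpy as np

def make_banded(A, kl, ku):
    # Keep entries with -kl <= j - i <= ku; zero the rest (hypothetical sketch)
    i, j = np.indices(A.shape)
    return np.where((j - i <= ku) & (i - j <= kl), A, 0.0)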

0 commit comments