Skip to content

Commit a902694

Browse files
Eliminate extra copy in numba impl
1 parent: b16189e · commit: a902694

File tree

1 file changed

+8
-7
lines changed
  • pytensor/link/numba/dispatch/linalg/dot/banded.py

1 file changed

+8
-7
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from numba import njit as numba_njit
55
from numba.core.extending import overload
6-
from numba.np.linalg import _copy_to_fortran_order, ensure_blas, ensure_lapack
6+
from numba.np.linalg import ensure_blas, ensure_lapack
77
from scipy import linalg
88

99
from pytensor.link.numba.dispatch.linalg._BLAS import _BLAS
@@ -15,9 +15,12 @@
1515

1616

1717
@numba_njit(inline="always")
18-
def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
18+
def A_to_banded(A: np.ndarray, kl: int, ku: int, order="C") -> np.ndarray:
1919
m, n = A.shape
20-
A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype)
20+
if order == "C":
21+
A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype)
22+
else:
23+
A_banded = np.zeros((n, kl + ku + 1), dtype=A.dtype).T
2124

2225
for i, k in enumerate(range(ku, -kl - 1, -1)):
2326
if k >= 0:
@@ -35,7 +38,7 @@ def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
3538
"""
3639
fn = linalg.get_blas_funcs("gbmv", (A, x))
3740
m, n = A.shape
38-
A_banded = A_to_banded(A, kl=kl, ku=ku)
41+
A_banded = A_to_banded(A, kl=kl, ku=ku, order="C")
3942

4043
return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
4144

@@ -54,9 +57,7 @@ def dot_banded_impl(
5457
def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
5558
m, n = A.shape
5659

57-
# TODO: Can we avoid this copy?
58-
A_banded = A_to_banded(A, kl=kl, ku=ku)
59-
A_banded = _copy_to_fortran_order(A_banded)
60+
A_banded = A_to_banded(A, kl=kl, ku=ku, order="F")
6061

6162
TRANS = val_to_int_ptr(ord("N"))
6263
M = val_to_int_ptr(m)

0 commit comments

Comments
 (0)