Skip to content

Commit 687877c

Browse files
set INCX by strides
1 parent 905fc7c commit 687877c

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
6969
KU = val_to_int_ptr(ku)
7070

7171
ALPHA = np.array(1.0, dtype=dtype)
72-
INCX = val_to_int_ptr(1)
72+
INCX = val_to_int_ptr(x.strides[0] // x.itemsize)
7373
BETA = np.array(0.0, dtype=dtype)
7474
Y = np.empty(m, dtype=dtype)
7575
INCY = val_to_int_ptr(1)

tests/link/numba/test_slinalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,3 +743,15 @@ def test_banded_dot():
743743
numba_mode=numba_inplace_mode,
744744
eval_obj_mode=False,
745745
)
746+
747+
# Test non-contiguous x input
748+
x_val = rng.normal(size=(20,))[::2]
749+
750+
compare_numba_and_py(
751+
[A, x],
752+
output,
753+
test_inputs=[A_val, x_val],
754+
inplace=True,
755+
numba_mode=numba_inplace_mode,
756+
eval_obj_mode=False,
757+
)

0 commit comments

Comments
 (0)