Skip to content

Commit 2b5c51d

Browse files
Test strides
1 parent 0505c57 commit 2b5c51d

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,11 +724,15 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
724724
np.testing.assert_allclose(b_val_not_contig, b_val)
725725

726726

727-
def test_banded_dot():
727+
@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}")
728+
def test_banded_dot(stride):
728729
rng = np.random.default_rng()
729730

730731
A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)
731-
x_val = rng.normal(size=(10,)).astype(config.floatX)
732+
733+
x_shape = (10 * abs(stride),)
734+
x_val = rng.normal(size=x_shape).astype(config.floatX)
735+
x_val = x_val[::stride]
732736

733737
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
734738
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)

tests/tensor/test_slinalg.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,20 +1061,20 @@ def _make_banded_A(A, kl, ku):
10611061
return sum(np.diag(d, k=k) for k, d in zip(diag_idxs, diags))
10621062

10631063

1064-
@pytest.mark.parametrize(
1065-
"A_shape",
1066-
[
1067-
(10, 10),
1068-
],
1069-
)
10701064
@pytest.mark.parametrize(
10711065
"kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
10721066
)
1073-
def test_banded_dot(A_shape, kl, ku):
1067+
@pytest.mark.parametrize("stride", [1, 2, -1], ids=lambda x: f"stride={x}")
1068+
def test_banded_dot(kl, ku, stride):
10741069
rng = np.random.default_rng()
10751070

1076-
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX)
1077-
x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
1071+
size = 10
1072+
1073+
A_val = _make_banded_A(rng.normal(size=(size, size)), kl=kl, ku=ku).astype(
1074+
config.floatX
1075+
)
1076+
x_val = rng.normal(size=(size * abs(stride),)).astype(config.floatX)
1077+
x_val = x_val[::stride]
10781078

10791079
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
10801080
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)

0 commit comments

Comments
 (0)