Skip to content

Commit c687856

Browse files
Simplify perf test
1 parent db5b23c commit c687856

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tests/tensor/test_slinalg.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,15 +1091,14 @@ def test_banded_dot(A_shape, kl, ku):
10911091

10921092
@pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str)
10931093
@pytest.mark.parametrize(
1094-
"A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"]
1095-
)
1096-
@pytest.mark.parametrize(
1097-
"kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
1094+
"A_shape",
1095+
[(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)],
1096+
ids=["10", "100", "1000", "10_000"],
10981097
)
1099-
def test_banded_dot_perf(op, A_shape, kl, ku, benchmark):
1098+
def test_banded_dot_perf(op, A_shape, benchmark):
11001099
rng = np.random.default_rng()
11011100

1102-
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
1101+
A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1)
11031102
b_val = rng.normal(size=(A_shape[-1],))
11041103

11051104
A = pt.tensor("A", shape=A_val.shape)
@@ -1108,7 +1107,7 @@ def test_banded_dot_perf(op, A_shape, kl, ku, benchmark):
11081107
if op == "dot":
11091108
f = pt.dot
11101109
elif op == "banded_dot":
1111-
f = functools.partial(banded_dot, lower_diags=kl, upper_diags=ku)
1110+
f = functools.partial(banded_dot, lower_diags=1, upper_diags=1)
11121111

11131112
res = f(A, b)
11141113
fn = function([A, b], res, trust_input=True)

0 commit comments

Comments
 (0)