@@ -1091,15 +1091,14 @@ def test_banded_dot(A_shape, kl, ku):
1091
1091
1092
1092
@pytest .mark .parametrize ("op" , ["dot" , "banded_dot" ], ids = str )
1093
1093
@pytest .mark .parametrize (
1094
- "A_shape" , [(10 , 10 ), (100 , 100 ), (1000 , 1000 )], ids = ["10" , "100" , "1000" ]
1095
- )
1096
- @pytest .mark .parametrize (
1097
- "kl, ku" , [(1 , 1 ), (0 , 1 ), (2 , 2 )], ids = ["tridiag" , "upper-only" , "banded" ]
1094
+ "A_shape" ,
1095
+ [(10 , 10 ), (100 , 100 ), (1000 , 1000 ), (10_000 , 10_000 )],
1096
+ ids = ["10" , "100" , "1000" , "10_000" ],
1098
1097
)
1099
- def test_banded_dot_perf (op , A_shape , kl , ku , benchmark ):
1098
+ def test_banded_dot_perf (op , A_shape , benchmark ):
1100
1099
rng = np .random .default_rng ()
1101
1100
1102
- A_val = _make_banded_A (rng .normal (size = A_shape ), kl = kl , ku = ku )
1101
+ A_val = _make_banded_A (rng .normal (size = A_shape ), kl = 1 , ku = 1 )
1103
1102
b_val = rng .normal (size = (A_shape [- 1 ],))
1104
1103
1105
1104
A = pt .tensor ("A" , shape = A_val .shape )
@@ -1108,7 +1107,7 @@ def test_banded_dot_perf(op, A_shape, kl, ku, benchmark):
1108
1107
if op == "dot" :
1109
1108
f = pt .dot
1110
1109
elif op == "banded_dot" :
1111
- f = functools .partial (banded_dot , lower_diags = kl , upper_diags = ku )
1110
+ f = functools .partial (banded_dot , lower_diags = 1 , upper_diags = 1 )
1112
1111
1113
1112
res = f (A , b )
1114
1113
fn = function ([A , b ], res , trust_input = True )
0 commit comments