@@ -1081,42 +1081,22 @@ def test_banded_dot(A_shape, kl, ku):
10811081 res = banded_dot (A , b , kl , ku )
10821082 res_2 = A @ b
10831083
1084- fn = function ([A , b ], [res , res_2 ])
1084+ fn = function ([A , b ], [res , res_2 ], trust_input = True )
10851085 assert any (isinstance (node .op , BandedDot ) for node in fn .maker .fgraph .apply_nodes )
10861086
10871087 x_val , x2_val = fn (A_val , b_val )
10881088
10891089 np .testing .assert_allclose (x_val , x2_val )
10901090
10911091
1092+ @pytest .mark .parametrize ("op" , ["dot" , "banded_dot" ], ids = str )
10921093@pytest .mark .parametrize (
10931094 "A_shape" , [(10 , 10 ), (100 , 100 ), (1000 , 1000 )], ids = ["10" , "100" , "1000" ]
10941095)
10951096@pytest .mark .parametrize (
10961097 "kl, ku" , [(1 , 1 ), (0 , 1 ), (2 , 2 )], ids = ["tridiag" , "upper-only" , "banded" ]
10971098)
1098- def test_banded_dot_perf (A_shape , kl , ku , benchmark ):
1099- rng = np .random .default_rng ()
1100-
1101- A_val = _make_banded_A (rng .normal (size = A_shape ), kl = kl , ku = ku )
1102- b_val = rng .normal (size = (A_shape [- 1 ],))
1103-
1104- A = pt .tensor ("A" , shape = A_val .shape , dtype = A_val .dtype )
1105- b = pt .tensor ("b" , shape = b_val .shape , dtype = b_val .dtype )
1106-
1107- res = banded_dot (A , b , kl , ku )
1108- fn = function ([A , b ], res , trust_input = True )
1109-
1110- benchmark (fn , A_val , b_val )
1111-
1112-
1113- @pytest .mark .parametrize (
1114- "A_shape" , [(10 , 10 ), (100 , 100 ), (1000 , 1000 )], ids = ["10" , "100" , "1000" ]
1115- )
1116- @pytest .mark .parametrize (
1117- "kl, ku" , [(1 , 1 ), (0 , 1 ), (2 , 2 )], ids = ["tridiag" , "upper-only" , "banded" ]
1118- )
1119- def test_dot_perf (A_shape , kl , ku , benchmark ):
1099+ def test_banded_dot_perf (op , A_shape , kl , ku , benchmark ):
11201100 rng = np .random .default_rng ()
11211101
11221102 A_val = _make_banded_A (rng .normal (size = A_shape ), kl = kl , ku = ku )
@@ -1125,7 +1105,12 @@ def test_dot_perf(A_shape, kl, ku, benchmark):
11251105 A = pt .tensor ("A" , shape = A_val .shape )
11261106 b = pt .tensor ("b" , shape = b_val .shape )
11271107
1128- res = A @ b
1129- fn = function ([A , b ], res )
1108+ if op == "dot" :
1109+ f = pt .dot
1110+ elif op == "banded_dot" :
1111+ f = functools .partial (banded_dot , lower_diags = kl , upper_diags = ku )
1112+
1113+ res = f (A , b )
1114+ fn = function ([A , b ], res , trust_input = True )
11301115
11311116 benchmark (fn , A_val , b_val )
0 commit comments