Skip to content

Commit 37129b3

Browse files
committed
BENCH: linalg: add a benchmark for batched solve for several structures
1 parent cf214ab commit 37129b3

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

benchmarks/benchmarks/linalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,36 @@ def time_svd(self, size, contig, module):
102102
)
103103

104104

105+
class BatchedSolveBench(Benchmark):
106+
params = [
107+
[(100, 10, 10), (100, 20, 20), (100, 100)],
108+
["gen", "pos", "sym"],
109+
["scipy", "numpy"]
110+
]
111+
param_names = ["shape", "structure" ,"module"]
112+
113+
def setup(self, shape, structure, module):
114+
a = random(shape)
115+
# larger diagonal ensures non-singularity:
116+
for i in range(shape[-1]):
117+
a[..., i, i] = 10*(.1+a[..., i, i])
118+
119+
if structure == "pos":
120+
self.a = a @ a.mT
121+
elif structure == "sym":
122+
self.a = a + a.mT
123+
else:
124+
self.a = a
125+
126+
self.b = random([a.shape[-1]])
127+
128+
def time_solve(self, shape, structure, module):
129+
if module == 'numpy':
130+
nl.solve(self.a, self.b)
131+
else:
132+
sl.solve(self.a, self.b, assume_a=structure)
133+
134+
105135
class Norm(Benchmark):
106136
params = [
107137
[(20, 20), (100, 100), (1000, 1000), (20, 1000), (1000, 20)],

0 commit comments

Comments
 (0)