Skip to content

Commit 297612b

Browse files
authored
Add files via upload
1 parent 63561e0 commit 297612b

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

misc/benchmark_mm_trtrs_inv.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Benchmarking matrix multiplication, back substitution and inverse
2+
"""
3+
import torch
4+
import time
5+
import random
6+
import statistics
7+
8+
n = 1024
9+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10+
results = []#check whether NaN shows up
11+
times_mm, times_trtrs, times_inv = [], [], []
12+
with torch.no_grad():
13+
while min([len(times_mm), len(times_trtrs), len(times_inv)])<3:
14+
A = torch.randn(n, n, device=device) + 10*torch.eye(n, device=device)
15+
b = torch.randn(n, n, device=device)
16+
j = random.randint(0, 2)#call mm, trtrs, inverse in random order
17+
if j==0:
18+
t0 = time.time()
19+
x = A.mm(b)
20+
results.append(x[0,0])
21+
times_mm.append(time.time() - t0)
22+
elif j==1:
23+
t0 = time.time()
24+
x = torch.trtrs(b, A)[0]#just take triangular part of A
25+
results.append(x[0,0])
26+
times_trtrs.append(time.time() - t0)
27+
else:
28+
t0 = time.time()
29+
x = torch.inverse(A)
30+
results.append(x[0,0])
31+
times_inv.append(time.time() - t0)
32+
33+
print('Median Time in ms:')
34+
print('Multiplication {}; BackSubstitution {}; Inversion {}'.format(
35+
1000*statistics.median(times_mm),
36+
1000*statistics.median(times_trtrs),
37+
1000*statistics.median(times_inv)))

0 commit comments

Comments
 (0)