Skip to content

Commit fa998f2

Browse files
committed
careful with array dtypes per BLAS variant
1 parent 986b154 commit fa998f2

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

benchmarks/benchmarks.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
import numpy as np
55
import openblas_wrap as ow
66

7+
dtype_map = {
8+
's': np.float32,
9+
'd': np.float64,
10+
'c': np.complex64,
11+
'z': np.complex128,
12+
'dz': np.complex128,
13+
}
14+
715

816
# ### BLAS level 1 ###
917

@@ -24,7 +32,9 @@ class Nrm2:
2432

2533
def setup(self, n, variant):
2634
rndm = np.random.RandomState(1234)
27-
self.x = rndm.uniform(size=(n,)).astype(float)
35+
dtyp = dtype_map[variant]
36+
37+
self.x = rndm.uniform(size=(n,)).astype(dtyp)
2838
self.nrm2 = ow.get_func('nrm2', variant)
2939

3040
def time_dnrm2(self, n, variant):
@@ -46,8 +56,10 @@ class DDot:
4656

4757
def setup(self, n):
4858
rndm = np.random.RandomState(1234)
49-
self.x = np.array(rndm.uniform(size=(n,)), dtype=float)
50-
self.y = np.array(rndm.uniform(size=(n,)), dtype=float)
59+
dtyp = float
60+
61+
self.x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
62+
self.y = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
5163
self.func = ow.get_func('dot', 'd')
5264

5365
def time_ddot(self, n):
@@ -70,8 +82,10 @@ class Daxpy:
7082

7183
def setup(self, n, variant):
7284
rndm = np.random.RandomState(1234)
73-
self.x = np.array(rndm.uniform(size=(n,)), dtype=float)
74-
self.y = np.array(rndm.uniform(size=(n,)), dtype=float)
85+
dtyp = dtype_map[variant]
86+
87+
self.x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
88+
self.y = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
7589
self.axpy = ow.get_func('axpy', variant)
7690

7791
def time_daxpy(self, n, variant):
@@ -97,9 +111,11 @@ class Dgemm:
97111

98112
def setup(self, n, variant):
99113
rndm = np.random.RandomState(1234)
100-
self.a = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
101-
self.b = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
102-
self.c = np.empty((n, n), dtype=float, order='F')
114+
dtyp = dtype_map[variant]
115+
116+
self.a = np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F')
117+
self.b = np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F')
118+
self.c = np.empty((n, n), dtype=dtyp, order='F')
103119
self.func = ow.get_func('gemm', variant)
104120

105121
def time_dgemm(self, n, variant):
@@ -122,8 +138,10 @@ class DSyrk:
122138

123139
def setup(self, n, variant):
124140
rndm = np.random.RandomState(1234)
125-
self.a = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
126-
self.c = np.empty((n, n), dtype=float, order='F')
141+
dtyp = dtype_map[variant]
142+
143+
self.a = np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F')
144+
self.c = np.empty((n, n), dtype=dtyp, order='F')
127145
self.func = ow.get_func('syrk', variant)
128146

129147
def time_dsyrk(self, n, variant):
@@ -148,9 +166,11 @@ class Dgesv:
148166

149167
def setup(self, n, variant):
150168
rndm = np.random.RandomState(1234)
151-
self.a = (np.array(rndm.uniform(size=(n, n)), dtype=float, order='F') +
152-
np.eye(n, order='F'))
153-
self.b = np.array(rndm.uniform(size=(n, 1)), order='F')
169+
dtyp = dtype_map[variant]
170+
171+
self.a = (np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F') +
172+
np.eye(n, dtype=dtyp, order='F'))
173+
self.b = np.array(rndm.uniform(size=(n, 1)), dtype=dtyp, order='F')
154174
self.func = ow.get_func('gesv', variant)
155175

156176
def time_dgesv(self, n, variant):
@@ -181,7 +201,9 @@ def setup(self, mn, variant):
181201
m, n = (int(x) for x in mn.split(","))
182202

183203
rndm = np.random.RandomState(1234)
184-
a = np.array(rndm.uniform(size=(m, n)), dtype=float, order='F')
204+
dtyp = dtype_map[variant]
205+
206+
a = np.array(rndm.uniform(size=(m, n)), dtype=dtyp, order='F')
185207

186208
gesdd_lwork = ow.get_func('gesdd_lwork', variant)
187209

@@ -212,8 +234,10 @@ class Dsyev:
212234

213235
def setup(self, n, variant):
214236
rndm = np.random.RandomState(1234)
237+
dtyp = dtype_map[variant]
238+
215239
a = rndm.uniform(size=(n, n))
216-
a = np.asarray(a + a.T, dtype=float, order='F')
240+
a = np.asarray(a + a.T, dtype=dtyp, order='F')
217241
a_ = a.copy()
218242

219243
syev_lwork = ow.get_func('syev_lwork', variant)

0 commit comments

Comments
 (0)