44import numpy as np
55import openblas_wrap as ow
66
7+ dtype_map = {
8+ 's' : np .float32 ,
9+ 'd' : np .float64 ,
10+ 'c' : np .complex64 ,
11+ 'z' : np .complex128 ,
12+ 'dz' : np .complex128 ,
13+ }
14+
715
816# ### BLAS level 1 ###
917
@@ -24,7 +32,9 @@ class Nrm2:
2432
2533 def setup (self , n , variant ):
2634 rndm = np .random .RandomState (1234 )
27- self .x = rndm .uniform (size = (n ,)).astype (float )
35+ dtyp = dtype_map [variant ]
36+
37+ self .x = rndm .uniform (size = (n ,)).astype (dtyp )
2838 self .nrm2 = ow .get_func ('nrm2' , variant )
2939
3040 def time_dnrm2 (self , n , variant ):
@@ -46,8 +56,10 @@ class DDot:
4656
4757 def setup (self , n ):
4858 rndm = np .random .RandomState (1234 )
49- self .x = np .array (rndm .uniform (size = (n ,)), dtype = float )
50- self .y = np .array (rndm .uniform (size = (n ,)), dtype = float )
59+ dtyp = float
60+
61+ self .x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
62+ self .y = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
5163 self .func = ow .get_func ('dot' , 'd' )
5264
5365 def time_ddot (self , n ):
@@ -70,8 +82,10 @@ class Daxpy:
7082
7183 def setup (self , n , variant ):
7284 rndm = np .random .RandomState (1234 )
73- self .x = np .array (rndm .uniform (size = (n ,)), dtype = float )
74- self .y = np .array (rndm .uniform (size = (n ,)), dtype = float )
85+ dtyp = dtype_map [variant ]
86+
87+ self .x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
88+ self .y = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
7589 self .axpy = ow .get_func ('axpy' , variant )
7690
7791 def time_daxpy (self , n , variant ):
@@ -97,9 +111,11 @@ class Dgemm:
97111
98112 def setup (self , n , variant ):
99113 rndm = np .random .RandomState (1234 )
100- self .a = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
101- self .b = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
102- self .c = np .empty ((n , n ), dtype = float , order = 'F' )
114+ dtyp = dtype_map [variant ]
115+
116+ self .a = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
117+ self .b = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
118+ self .c = np .empty ((n , n ), dtype = dtyp , order = 'F' )
103119 self .func = ow .get_func ('gemm' , variant )
104120
105121 def time_dgemm (self , n , variant ):
@@ -122,8 +138,10 @@ class DSyrk:
122138
123139 def setup (self , n , variant ):
124140 rndm = np .random .RandomState (1234 )
125- self .a = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
126- self .c = np .empty ((n , n ), dtype = float , order = 'F' )
141+ dtyp = dtype_map [variant ]
142+
143+ self .a = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
144+ self .c = np .empty ((n , n ), dtype = dtyp , order = 'F' )
127145 self .func = ow .get_func ('syrk' , variant )
128146
129147 def time_dsyrk (self , n , variant ):
@@ -148,9 +166,11 @@ class Dgesv:
148166
149167 def setup (self , n , variant ):
150168 rndm = np .random .RandomState (1234 )
151- self .a = (np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' ) +
152- np .eye (n , order = 'F' ))
153- self .b = np .array (rndm .uniform (size = (n , 1 )), order = 'F' )
169+ dtyp = dtype_map [variant ]
170+
171+ self .a = (np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' ) +
172+ np .eye (n , dtype = dtyp , order = 'F' ))
173+ self .b = np .array (rndm .uniform (size = (n , 1 )), dtype = dtyp , order = 'F' )
154174 self .func = ow .get_func ('gesv' , variant )
155175
156176 def time_dgesv (self , n , variant ):
@@ -181,7 +201,9 @@ def setup(self, mn, variant):
181201 m , n = (int (x ) for x in mn .split ("," ))
182202
183203 rndm = np .random .RandomState (1234 )
184- a = np .array (rndm .uniform (size = (m , n )), dtype = float , order = 'F' )
204+ dtyp = dtype_map [variant ]
205+
206+ a = np .array (rndm .uniform (size = (m , n )), dtype = dtyp , order = 'F' )
185207
186208 gesdd_lwork = ow .get_func ('gesdd_lwork' , variant )
187209
@@ -212,8 +234,10 @@ class Dsyev:
212234
213235 def setup (self , n , variant ):
214236 rndm = np .random .RandomState (1234 )
237+ dtyp = dtype_map [variant ]
238+
215239 a = rndm .uniform (size = (n , n ))
216- a = np .asarray (a + a .T , dtype = float , order = 'F' )
240+ a = np .asarray (a + a .T , dtype = dtyp , order = 'F' )
217241 a_ = a .copy ()
218242
219243 syev_lwork = ow .get_func ('syev_lwork' , variant )
0 commit comments