1414
1515from kernel_tuner .backends .backend import CompilerBackend
1616from kernel_tuner .observers .c import CRuntimeObserver
17- from kernel_tuner .util import get_temp_filename , delete_temp_file , write_file , SkippableFailure
18-
19- dtype_map = {"int8" : C .c_int8 ,
20- "int16" : C .c_int16 ,
21- "int32" : C .c_int32 ,
22- "int64" : C .c_int64 ,
23- "uint8" : C .c_uint8 ,
24- "uint16" : C .c_uint16 ,
25- "uint32" : C .c_uint32 ,
26- "uint64" : C .c_uint64 ,
27- "float32" : C .c_float ,
28- "float64" : C .c_double }
17+ from kernel_tuner .util import (
18+ get_temp_filename ,
19+ delete_temp_file ,
20+ write_file ,
21+ SkippableFailure ,
22+ )
23+
24+ dtype_map = {
25+ "int8" : C .c_int8 ,
26+ "int16" : C .c_int16 ,
27+ "int32" : C .c_int32 ,
28+ "int64" : C .c_int64 ,
29+ "uint8" : C .c_uint8 ,
30+ "uint16" : C .c_uint16 ,
31+ "uint32" : C .c_uint32 ,
32+ "uint64" : C .c_uint64 ,
33+ "float32" : C .c_float ,
34+ "float64" : C .c_double ,
35+ }
2936
3037# This represents an individual kernel argument.
3138# It contains a numpy object (ndarray or number) and a ctypes object with a copy
@@ -58,7 +65,7 @@ def __init__(self, iterations=7, compiler_options=None, compiler=None):
5865 except OSError as e :
5966 raise e
6067
61- #check if nvcc is available
68+ # check if nvcc is available
6269 self .nvcc_available = False
6370 try :
6471 nvcc_version = str (subprocess .check_output (["nvcc" , "--version" ]))
@@ -68,7 +75,7 @@ def __init__(self, iterations=7, compiler_options=None, compiler=None):
6875 if e .errno != errno .ENOENT :
6976 raise e
7077
71- #environment info
78+ # environment info
7279 env = dict ()
7380 env ["CC Version" ] = cc_version
7481 if self .nvcc_available :
@@ -89,19 +96,21 @@ def ready_argument_list(self, arguments):
8996 :returns: A list of arguments that can be passed to the C function.
9097 :rtype: list(Argument)
9198 """
92- ctype_args = [ None for _ in arguments ]
99+ ctype_args = [None for _ in arguments ]
93100
94101 for i , arg in enumerate (arguments ):
95102 if not isinstance (arg , (np .ndarray , np .number )):
96- raise TypeError ("Argument is not numpy ndarray or numpy scalar %s" % type (arg ))
103+ raise TypeError (
104+ "Argument is not numpy ndarray or numpy scalar %s" % type (arg )
105+ )
97106 dtype_str = str (arg .dtype )
98107 if isinstance (arg , np .ndarray ):
99108 if dtype_str in dtype_map .keys ():
100109 # In numpy <= 1.15, ndarray.ctypes.data_as does not itself keep a reference
101110 # to its underlying array, so we need to store a reference to arg.copy()
102111 # in the Argument object manually to avoid it being deleted.
103112 # (This changed in numpy > 1.15.)
104- #data_ctypes = data.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
113+ # data_ctypes = data.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
105114 data_ctypes = arg .ctypes .data_as (C .POINTER (dtype_map [dtype_str ]))
106115 else :
107116 raise TypeError ("unknown dtype for ndarray" )
@@ -120,7 +129,7 @@ def compile(self, kernel_instance):
120129 :returns: An ctypes function that can be called directly.
121130 :rtype: ctypes._FuncPtr
122131 """
123- logging .debug (' compiling ' + kernel_instance .name )
132+ logging .debug (" compiling " + kernel_instance .name )
124133
125134 kernel_string = kernel_instance .kernel_string
126135 kernel_name = kernel_instance .name
@@ -130,9 +139,9 @@ def compile(self, kernel_instance):
130139
131140 compiler_options = ["-fPIC" ]
132141
133- #detect openmp
142+ # detect openmp
134143 if "#include <omp.h>" in kernel_string or "use omp_lib" in kernel_string :
135- logging .debug (' set using_openmp to true' )
144+ logging .debug (" set using_openmp to true" )
136145 self .using_openmp = True
137146 if self .compiler == "pgfortran" :
138147 compiler_options .append ("-mp" )
@@ -142,20 +151,28 @@ def compile(self, kernel_instance):
142151 else :
143152 compiler_options .append ("-fopenmp" )
144153
145- #if filename is known, use that one
154+ # if filename is known, use that one
146155 suffix = kernel_instance .kernel_source .get_user_suffix ()
147156
148- #if code contains device code, suffix .cu is required
157+ # if code contains device code, suffix .cu is required
149158 device_code_signals = ["__global" , "__syncthreads()" , "threadIdx" ]
150159 if any ([snippet in kernel_string for snippet in device_code_signals ]):
151160 suffix = ".cu"
152161
153- #detect whether to use nvcc as default instead of g++, may overrule an explicitly passed g++
154- if ((suffix == ".cu" ) or ("#include <cuda" in kernel_string ) or ("cudaMemcpy" in kernel_string )) and self .compiler == "g++" and self .nvcc_available :
162+ # detect whether to use nvcc as default instead of g++, may overrule an explicitly passed g++
163+ if (
164+ (
165+ (suffix == ".cu" )
166+ or ("#include <cuda" in kernel_string )
167+ or ("cudaMemcpy" in kernel_string )
168+ )
169+ and self .compiler == "g++"
170+ and self .nvcc_available
171+ ):
155172 self .compiler = "nvcc"
156173
157174 if suffix is None :
158- #select right suffix based on compiler
175+ # select right suffix based on compiler
159176 suffix = ".cc"
160177
161178 if self .compiler in ["gfortran" , "pgfortran" , "ftn" , "ifort" ]:
@@ -164,27 +181,27 @@ def compile(self, kernel_instance):
164181 if self .compiler == "nvcc" :
165182 compiler_options = ["-Xcompiler=" + c for c in compiler_options ]
166183
167- #this basically checks if we aren't compiling Fortran
168- #at the moment any C, C++, or CUDA code is assumed to use extern "C" linkage
169- if ".c" in suffix and " extern \" C \" " not in kernel_string :
170- kernel_string = " extern \" C \ " {\n " + kernel_string + "\n }"
184+ # this basically checks if we aren't compiling Fortran
185+ # at the moment any C, C++, or CUDA code is assumed to use extern "C" linkage
186+ if ".c" in suffix and ' extern "C"' not in kernel_string :
187+ kernel_string = ' extern "C " {\n ' + kernel_string + "\n }"
171188
172- #copy user specified compiler options to current list
189+ # copy user specified compiler options to current list
173190 if self .compiler_options :
174191 compiler_options += self .compiler_options
175192
176193 lib_args = []
177194 if "CL/cl.h" in kernel_string :
178195 lib_args = ["-lOpenCL" ]
179196
180- logging .debug (' using compiler ' + self .compiler )
181- logging .debug (' compiler_options ' + " " .join (compiler_options ))
182- logging .debug (' lib_args ' + " " .join (lib_args ))
197+ logging .debug (" using compiler " + self .compiler )
198+ logging .debug (" compiler_options " + " " .join (compiler_options ))
199+ logging .debug (" lib_args " + " " .join (lib_args ))
183200
184201 source_file = get_temp_filename (suffix = suffix )
185202 filename = "." .join (source_file .split ("." )[:- 1 ])
186203
187- #detect Fortran modules
204+ # detect Fortran modules
188205 match = re .search (r"\s*module\s+([a-zA-Z_]*)" , kernel_string )
189206 if match :
190207 if self .compiler == "gfortran" :
@@ -194,7 +211,7 @@ def compile(self, kernel_instance):
194211 elif self .compiler == "pgfortran" :
195212 kernel_name = match .group (1 ) + "_" + kernel_name + "_"
196213 else :
197- #for functions outside of modules
214+ # for functions outside of modules
198215 if self .compiler in ["gfortran" , "ftn" , "ifort" , "pgfortran" ]:
199216 kernel_name = kernel_name + "_"
200217
@@ -205,44 +222,52 @@ def compile(self, kernel_instance):
205222 if platform .system () == "Darwin" :
206223 lib_extension = ".dylib"
207224
208- subprocess .check_call ([self .compiler , "-c" , source_file ] + compiler_options + ["-o" , filename + ".o" ])
209- subprocess .check_call ([self .compiler , filename + ".o" ] + compiler_options + ["-shared" , "-o" , filename + lib_extension ] + lib_args )
210-
211- self .lib = np .ctypeslib .load_library (filename , '.' )
225+ subprocess .check_call (
226+ [self .compiler , "-c" , source_file ]
227+ + compiler_options
228+ + ["-o" , filename + ".o" ]
229+ )
230+ subprocess .check_call (
231+ [self .compiler , filename + ".o" ]
232+ + compiler_options
233+ + ["-shared" , "-o" , filename + lib_extension ]
234+ + lib_args
235+ )
236+
237+ self .lib = np .ctypeslib .load_library (filename , "." )
212238 func = getattr (self .lib , kernel_name )
213239 func .restype = C .c_float
214240
215241 finally :
216242 delete_temp_file (source_file )
217- delete_temp_file (filename + ".o" )
218- delete_temp_file (filename + ".so" )
219- delete_temp_file (filename + ".dylib" )
243+ delete_temp_file (filename + ".o" )
244+ delete_temp_file (filename + ".so" )
245+ delete_temp_file (filename + ".dylib" )
220246
221247 return func
222248
223-
224249 def start_event (self ):
225- """ Records the event that marks the start of a measurement
250+ """Records the event that marks the start of a measurement
226251
227- C backend does not use events """
252+ C backend does not use events"""
228253 pass
229254
230255 def stop_event (self ):
231- """ Records the event that marks the end of a measurement
256+ """Records the event that marks the end of a measurement
232257
233- C backend does not use events """
258+ C backend does not use events"""
234259 pass
235260
236261 def kernel_finished (self ):
237- """ Returns True if the kernel has finished, False otherwise
262+ """Returns True if the kernel has finished, False otherwise
238263
239- C backend does not support asynchronous launches """
264+ C backend does not support asynchronous launches"""
240265 return True
241266
242267 def synchronize (self ):
243- """ Halts execution until device has finished its tasks
268+ """Halts execution until device has finished its tasks
244269
245- C backend does not support asynchronous launches """
270+ C backend does not support asynchronous launches"""
246271 pass
247272
248273 def run_kernel (self , func , c_args , threads , grid ):
@@ -312,11 +337,11 @@ def memcpy_htod(self, dest, src):
312337 dest .numpy [:] = src
313338
314339 def cleanup_lib (self ):
315- """ unload the previously loaded shared library """
340+ """unload the previously loaded shared library"""
316341 if not self .using_openmp :
317- #this if statement is necessary because shared libraries that use
318- #OpenMP will core dump when unloaded, this is a well-known issue with OpenMP
319- logging .debug (' unloading shared library' )
342+ # this if statement is necessary because shared libraries that use
343+ # OpenMP will core dump when unloaded, this is a well-known issue with OpenMP
344+ logging .debug (" unloading shared library" )
320345 _ctypes .dlclose (self .lib ._handle )
321346
322347 units = {}
0 commit comments