Skip to content

Commit aa01d8a

Browse files
committed
Backends reformatted with black.
1 parent ca2bbdf commit aa01d8a

File tree

5 files changed

+270
-188
lines changed

5 files changed

+270
-188
lines changed

kernel_tuner/backends/c.py

Lines changed: 80 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,25 @@
1414

1515
from kernel_tuner.backends.backend import CompilerBackend
1616
from kernel_tuner.observers.c import CRuntimeObserver
17-
from kernel_tuner.util import get_temp_filename, delete_temp_file, write_file, SkippableFailure
18-
19-
dtype_map = {"int8": C.c_int8,
20-
"int16": C.c_int16,
21-
"int32": C.c_int32,
22-
"int64": C.c_int64,
23-
"uint8": C.c_uint8,
24-
"uint16": C.c_uint16,
25-
"uint32": C.c_uint32,
26-
"uint64": C.c_uint64,
27-
"float32": C.c_float,
28-
"float64": C.c_double}
17+
from kernel_tuner.util import (
18+
get_temp_filename,
19+
delete_temp_file,
20+
write_file,
21+
SkippableFailure,
22+
)
23+
24+
dtype_map = {
25+
"int8": C.c_int8,
26+
"int16": C.c_int16,
27+
"int32": C.c_int32,
28+
"int64": C.c_int64,
29+
"uint8": C.c_uint8,
30+
"uint16": C.c_uint16,
31+
"uint32": C.c_uint32,
32+
"uint64": C.c_uint64,
33+
"float32": C.c_float,
34+
"float64": C.c_double,
35+
}
2936

3037
# This represents an individual kernel argument.
3138
# It contains a numpy object (ndarray or number) and a ctypes object with a copy
@@ -58,7 +65,7 @@ def __init__(self, iterations=7, compiler_options=None, compiler=None):
5865
except OSError as e:
5966
raise e
6067

61-
#check if nvcc is available
68+
# check if nvcc is available
6269
self.nvcc_available = False
6370
try:
6471
nvcc_version = str(subprocess.check_output(["nvcc", "--version"]))
@@ -68,7 +75,7 @@ def __init__(self, iterations=7, compiler_options=None, compiler=None):
6875
if e.errno != errno.ENOENT:
6976
raise e
7077

71-
#environment info
78+
# environment info
7279
env = dict()
7380
env["CC Version"] = cc_version
7481
if self.nvcc_available:
@@ -89,19 +96,21 @@ def ready_argument_list(self, arguments):
8996
:returns: A list of arguments that can be passed to the C function.
9097
:rtype: list(Argument)
9198
"""
92-
ctype_args = [ None for _ in arguments]
99+
ctype_args = [None for _ in arguments]
93100

94101
for i, arg in enumerate(arguments):
95102
if not isinstance(arg, (np.ndarray, np.number)):
96-
raise TypeError("Argument is not numpy ndarray or numpy scalar %s" % type(arg))
103+
raise TypeError(
104+
"Argument is not numpy ndarray or numpy scalar %s" % type(arg)
105+
)
97106
dtype_str = str(arg.dtype)
98107
if isinstance(arg, np.ndarray):
99108
if dtype_str in dtype_map.keys():
100109
# In numpy <= 1.15, ndarray.ctypes.data_as does not itself keep a reference
101110
# to its underlying array, so we need to store a reference to arg.copy()
102111
# in the Argument object manually to avoid it being deleted.
103112
# (This changed in numpy > 1.15.)
104-
#data_ctypes = data.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
113+
# data_ctypes = data.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
105114
data_ctypes = arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
106115
else:
107116
raise TypeError("unknown dtype for ndarray")
@@ -120,7 +129,7 @@ def compile(self, kernel_instance):
120129
:returns: An ctypes function that can be called directly.
121130
:rtype: ctypes._FuncPtr
122131
"""
123-
logging.debug('compiling ' + kernel_instance.name)
132+
logging.debug("compiling " + kernel_instance.name)
124133

125134
kernel_string = kernel_instance.kernel_string
126135
kernel_name = kernel_instance.name
@@ -130,9 +139,9 @@ def compile(self, kernel_instance):
130139

131140
compiler_options = ["-fPIC"]
132141

133-
#detect openmp
142+
# detect openmp
134143
if "#include <omp.h>" in kernel_string or "use omp_lib" in kernel_string:
135-
logging.debug('set using_openmp to true')
144+
logging.debug("set using_openmp to true")
136145
self.using_openmp = True
137146
if self.compiler == "pgfortran":
138147
compiler_options.append("-mp")
@@ -142,20 +151,28 @@ def compile(self, kernel_instance):
142151
else:
143152
compiler_options.append("-fopenmp")
144153

145-
#if filename is known, use that one
154+
# if filename is known, use that one
146155
suffix = kernel_instance.kernel_source.get_user_suffix()
147156

148-
#if code contains device code, suffix .cu is required
157+
# if code contains device code, suffix .cu is required
149158
device_code_signals = ["__global", "__syncthreads()", "threadIdx"]
150159
if any([snippet in kernel_string for snippet in device_code_signals]):
151160
suffix = ".cu"
152161

153-
#detect whether to use nvcc as default instead of g++, may overrule an explicitly passed g++
154-
if ((suffix == ".cu") or ("#include <cuda" in kernel_string) or ("cudaMemcpy" in kernel_string)) and self.compiler == "g++" and self.nvcc_available:
162+
# detect whether to use nvcc as default instead of g++, may overrule an explicitly passed g++
163+
if (
164+
(
165+
(suffix == ".cu")
166+
or ("#include <cuda" in kernel_string)
167+
or ("cudaMemcpy" in kernel_string)
168+
)
169+
and self.compiler == "g++"
170+
and self.nvcc_available
171+
):
155172
self.compiler = "nvcc"
156173

157174
if suffix is None:
158-
#select right suffix based on compiler
175+
# select right suffix based on compiler
159176
suffix = ".cc"
160177

161178
if self.compiler in ["gfortran", "pgfortran", "ftn", "ifort"]:
@@ -164,27 +181,27 @@ def compile(self, kernel_instance):
164181
if self.compiler == "nvcc":
165182
compiler_options = ["-Xcompiler=" + c for c in compiler_options]
166183

167-
#this basically checks if we aren't compiling Fortran
168-
#at the moment any C, C++, or CUDA code is assumed to use extern "C" linkage
169-
if ".c" in suffix and "extern \"C\"" not in kernel_string:
170-
kernel_string = "extern \"C\" {\n" + kernel_string + "\n}"
184+
# this basically checks if we aren't compiling Fortran
185+
# at the moment any C, C++, or CUDA code is assumed to use extern "C" linkage
186+
if ".c" in suffix and 'extern "C"' not in kernel_string:
187+
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
171188

172-
#copy user specified compiler options to current list
189+
# copy user specified compiler options to current list
173190
if self.compiler_options:
174191
compiler_options += self.compiler_options
175192

176193
lib_args = []
177194
if "CL/cl.h" in kernel_string:
178195
lib_args = ["-lOpenCL"]
179196

180-
logging.debug('using compiler ' + self.compiler)
181-
logging.debug('compiler_options ' + " ".join(compiler_options))
182-
logging.debug('lib_args ' + " ".join(lib_args))
197+
logging.debug("using compiler " + self.compiler)
198+
logging.debug("compiler_options " + " ".join(compiler_options))
199+
logging.debug("lib_args " + " ".join(lib_args))
183200

184201
source_file = get_temp_filename(suffix=suffix)
185202
filename = ".".join(source_file.split(".")[:-1])
186203

187-
#detect Fortran modules
204+
# detect Fortran modules
188205
match = re.search(r"\s*module\s+([a-zA-Z_]*)", kernel_string)
189206
if match:
190207
if self.compiler == "gfortran":
@@ -194,7 +211,7 @@ def compile(self, kernel_instance):
194211
elif self.compiler == "pgfortran":
195212
kernel_name = match.group(1) + "_" + kernel_name + "_"
196213
else:
197-
#for functions outside of modules
214+
# for functions outside of modules
198215
if self.compiler in ["gfortran", "ftn", "ifort", "pgfortran"]:
199216
kernel_name = kernel_name + "_"
200217

@@ -205,44 +222,52 @@ def compile(self, kernel_instance):
205222
if platform.system() == "Darwin":
206223
lib_extension = ".dylib"
207224

208-
subprocess.check_call([self.compiler, "-c", source_file] + compiler_options + ["-o", filename + ".o"])
209-
subprocess.check_call([self.compiler, filename + ".o"] + compiler_options + ["-shared", "-o", filename + lib_extension] + lib_args)
210-
211-
self.lib = np.ctypeslib.load_library(filename, '.')
225+
subprocess.check_call(
226+
[self.compiler, "-c", source_file]
227+
+ compiler_options
228+
+ ["-o", filename + ".o"]
229+
)
230+
subprocess.check_call(
231+
[self.compiler, filename + ".o"]
232+
+ compiler_options
233+
+ ["-shared", "-o", filename + lib_extension]
234+
+ lib_args
235+
)
236+
237+
self.lib = np.ctypeslib.load_library(filename, ".")
212238
func = getattr(self.lib, kernel_name)
213239
func.restype = C.c_float
214240

215241
finally:
216242
delete_temp_file(source_file)
217-
delete_temp_file(filename+".o")
218-
delete_temp_file(filename+".so")
219-
delete_temp_file(filename+".dylib")
243+
delete_temp_file(filename + ".o")
244+
delete_temp_file(filename + ".so")
245+
delete_temp_file(filename + ".dylib")
220246

221247
return func
222248

223-
224249
def start_event(self):
225-
""" Records the event that marks the start of a measurement
250+
"""Records the event that marks the start of a measurement
226251
227-
C backend does not use events """
252+
C backend does not use events"""
228253
pass
229254

230255
def stop_event(self):
231-
""" Records the event that marks the end of a measurement
256+
"""Records the event that marks the end of a measurement
232257
233-
C backend does not use events """
258+
C backend does not use events"""
234259
pass
235260

236261
def kernel_finished(self):
237-
""" Returns True if the kernel has finished, False otherwise
262+
"""Returns True if the kernel has finished, False otherwise
238263
239-
C backend does not support asynchronous launches """
264+
C backend does not support asynchronous launches"""
240265
return True
241266

242267
def synchronize(self):
243-
""" Halts execution until device has finished its tasks
268+
"""Halts execution until device has finished its tasks
244269
245-
C backend does not support asynchronous launches """
270+
C backend does not support asynchronous launches"""
246271
pass
247272

248273
def run_kernel(self, func, c_args, threads, grid):
@@ -312,11 +337,11 @@ def memcpy_htod(self, dest, src):
312337
dest.numpy[:] = src
313338

314339
def cleanup_lib(self):
315-
""" unload the previously loaded shared library """
340+
"""unload the previously loaded shared library"""
316341
if not self.using_openmp:
317-
#this if statement is necessary because shared libraries that use
318-
#OpenMP will core dump when unloaded, this is a well-known issue with OpenMP
319-
logging.debug('unloading shared library')
342+
# this if statement is necessary because shared libraries that use
343+
# OpenMP will core dump when unloaded, this is a well-known issue with OpenMP
344+
logging.debug("unloading shared library")
320345
_ctypes.dlclose(self.lib._handle)
321346

322347
units = {}

0 commit comments

Comments
 (0)