Skip to content

Commit b5d92d2

Browse files
committed
Importing the correct modules, and using them.
1 parent 331c1f1 commit b5d92d2

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

kernel_tuner/backends/nvcuda.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
try:
1313
from cuda.bindings import driver, runtime, nvrtc
1414
except ImportError:
15-
cuda = None
15+
driver = None
1616

1717

1818
class CudaFunctions(GPUBackend):
@@ -38,34 +38,34 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
3838
"""
3939
self.allocations = []
4040
self.texrefs = []
41-
if not cuda:
41+
if not driver:
4242
raise ImportError(
4343
"cuda-python not installed, install using 'pip install cuda-python', or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#cuda-and-pycuda."
4444
)
4545

4646
# initialize and select device
47-
err = cuda.cuInit(0)
47+
err = driver.cuInit(0)
4848
cuda_error_check(err)
4949
err, self.device = cuda.cuDeviceGet(device)
5050
cuda_error_check(err)
51-
err, self.context = cuda.cuDevicePrimaryCtxRetain(device)
51+
err, self.context = driver.cuDevicePrimaryCtxRetain(device)
5252
cuda_error_check(err)
5353
if CudaFunctions.last_selected_device != device:
54-
err = cuda.cuCtxSetCurrent(self.context)
54+
err = driver.cuCtxSetCurrent(self.context)
5555
cuda_error_check(err)
5656
CudaFunctions.last_selected_device = device
5757

5858
# compute capabilities and device properties
59-
err, major = cudart.cudaDeviceGetAttribute(
60-
cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device
59+
err, major = runtime.cudaDeviceGetAttribute(
60+
runtime.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device
6161
)
6262
cuda_error_check(err)
63-
err, minor = cudart.cudaDeviceGetAttribute(
64-
cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device
63+
err, minor = runtime.cudaDeviceGetAttribute(
64+
runtime.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device
6565
)
6666
cuda_error_check(err)
67-
err, self.max_threads = cudart.cudaDeviceGetAttribute(
68-
cudart.cudaDeviceAttr.cudaDevAttrMaxThreadsPerBlock, device
67+
err, self.max_threads = runtime.cudaDeviceGetAttribute(
68+
runtime.cudaDeviceAttr.cudaDevAttrMaxThreadsPerBlock, device
6969
)
7070
cuda_error_check(err)
7171
self.cc = f"{major}{minor}"
@@ -78,11 +78,11 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
7878
self.compiler_options_bytes.append(str(option).encode("UTF-8"))
7979

8080
# create a stream and events
81-
err, self.stream = cuda.cuStreamCreate(0)
81+
err, self.stream = driver.cuStreamCreate(0)
8282
cuda_error_check(err)
83-
err, self.start = cuda.cuEventCreate(0)
83+
err, self.start = driver.cuEventCreate(0)
8484
cuda_error_check(err)
85-
err, self.end = cuda.cuEventCreate(0)
85+
err, self.end = driver.cuEventCreate(0)
8686
cuda_error_check(err)
8787

8888
# default dynamically allocated shared memory size, can be overwritten using smem_args
@@ -95,7 +95,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
9595
observer.register_device(self)
9696

9797
# collect environment information
98-
err, device_properties = cudart.cudaGetDeviceProperties(device)
98+
err, device_properties = runtime.cudaGetDeviceProperties(device)
9999
cuda_error_check(err)
100100
env = dict()
101101
env["device_name"] = device_properties.name.decode()
@@ -109,8 +109,8 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
109109

110110
def __del__(self):
111111
for device_memory in self.allocations:
112-
if isinstance(device_memory, cuda.CUdeviceptr):
113-
err = cuda.cuMemFree(device_memory)
112+
if isinstance(device_memory, driver.CUdeviceptr):
113+
err = driver.cuMemFree(device_memory)
114114
cuda_error_check(err)
115115

116116
def ready_argument_list(self, arguments):
@@ -128,7 +128,7 @@ def ready_argument_list(self, arguments):
128128
for arg in arguments:
129129
# if arg is a numpy array copy it to device
130130
if isinstance(arg, np.ndarray):
131-
err, device_memory = cuda.cuMemAlloc(arg.nbytes)
131+
err, device_memory = driver.cuMemAlloc(arg.nbytes)
132132
cuda_error_check(err)
133133
self.allocations.append(device_memory)
134134
gpu_args.append(device_memory)
@@ -184,18 +184,18 @@ def compile(self, kernel_instance):
184184
buff = b" " * size
185185
err = nvrtc.nvrtcGetPTX(program, buff)
186186
cuda_error_check(err)
187-
err, self.current_module = cuda.cuModuleLoadData(np.char.array(buff))
188-
if err == cuda.CUresult.CUDA_ERROR_INVALID_PTX:
187+
err, self.current_module = driver.cuModuleLoadData(np.char.array(buff))
188+
if err == driver.CUresult.CUDA_ERROR_INVALID_PTX:
189189
raise SkippableFailure("uses too much shared data")
190190
else:
191191
cuda_error_check(err)
192-
err, self.func = cuda.cuModuleGetFunction(
192+
err, self.func = driver.cuModuleGetFunction(
193193
self.current_module, str.encode(kernel_name)
194194
)
195195
cuda_error_check(err)
196196

197197
# get the number of registers per thread used in this kernel
198-
num_regs = cuda.cuFuncGetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, self.func)
198+
num_regs = driver.cuFuncGetAttribute(driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, self.func)
199199
assert num_regs[0] == 0, f"Retrieving number of registers per thread unsuccesful: code {num_regs[0]}"
200200
self.num_regs = num_regs[1]
201201

@@ -210,26 +210,26 @@ def compile(self, kernel_instance):
210210

211211
def start_event(self):
212212
"""Records the event that marks the start of a measurement."""
213-
err = cudart.cudaEventRecord(self.start, self.stream)
213+
err = runtime.cudaEventRecord(self.start, self.stream)
214214
cuda_error_check(err)
215215

216216
def stop_event(self):
217217
"""Records the event that marks the end of a measurement."""
218-
err = cudart.cudaEventRecord(self.end, self.stream)
218+
err = runtime.cudaEventRecord(self.end, self.stream)
219219
cuda_error_check(err)
220220

221221
def kernel_finished(self):
222222
"""Returns True if the kernel has finished, False otherwise."""
223-
err = cudart.cudaEventQuery(self.end)
224-
if err[0] == cudart.cudaError_t.cudaSuccess:
223+
err = runtime.cudaEventQuery(self.end)
224+
if err[0] == runtime.cudaError_t.cudaSuccess:
225225
return True
226226
else:
227227
return False
228228

229229
@staticmethod
230230
def synchronize():
231231
"""Halts execution until device has finished its tasks."""
232-
err = cudart.cudaDeviceSynchronize()
232+
err = runtime.cudaDeviceSynchronize()
233233
cuda_error_check(err)
234234

235235
def copy_constant_memory_args(self, cmem_args):
@@ -243,9 +243,9 @@ def copy_constant_memory_args(self, cmem_args):
243243
:type cmem_args: dict( string: numpy.ndarray, ... )
244244
"""
245245
for k, v in cmem_args.items():
246-
err, symbol, _ = cuda.cuModuleGetGlobal(self.current_module, str.encode(k))
246+
err, symbol, _ = driver.cuModuleGetGlobal(self.current_module, str.encode(k))
247247
cuda_error_check(err)
248-
err = cuda.cuMemcpyHtoD(symbol, v, v.nbytes)
248+
err = driver.cuMemcpyHtoD(symbol, v, v.nbytes)
249249
cuda_error_check(err)
250250

251251
def copy_shared_memory_args(self, smem_args):
@@ -284,12 +284,12 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
284284
stream = self.stream
285285
arg_types = list()
286286
for arg in gpu_args:
287-
if isinstance(arg, cuda.CUdeviceptr):
287+
if isinstance(arg, driver.CUdeviceptr):
288288
arg_types.append(None)
289289
else:
290290
arg_types.append(np.ctypeslib.as_ctypes_type(arg.dtype))
291291
kernel_args = (tuple(gpu_args), tuple(arg_types))
292-
err = cuda.cuLaunchKernel(
292+
err = driver.cuLaunchKernel(
293293
func,
294294
grid[0],
295295
grid[1],
@@ -318,7 +318,7 @@ def memset(allocation, value, size):
318318
:type size: int
319319
320320
"""
321-
err = cudart.cudaMemset(allocation, value, size)
321+
err = runtime.cudaMemset(allocation, value, size)
322322
cuda_error_check(err)
323323

324324
@staticmethod
@@ -331,7 +331,7 @@ def memcpy_dtoh(dest, src):
331331
:param src: A GPU memory allocation unit
332332
:type src: cuda.CUdeviceptr
333333
"""
334-
err = cuda.cuMemcpyDtoH(dest, src, dest.nbytes)
334+
err = driver.cuMemcpyDtoH(dest, src, dest.nbytes)
335335
cuda_error_check(err)
336336

337337
@staticmethod
@@ -344,7 +344,7 @@ def memcpy_htod(dest, src):
344344
:param src: A numpy array in host memory to store the data
345345
:type src: numpy.ndarray
346346
"""
347-
err = cuda.cuMemcpyHtoD(dest, src, src.nbytes)
347+
err = driver.cuMemcpyHtoD(dest, src, src.nbytes)
348348
cuda_error_check(err)
349349

350350
units = {"time": "ms"}

0 commit comments

Comments (0)