Skip to content

Commit ce4dc1e

Browse files
authored
Merge pull request #1 from MiloLurati/PyHIPCov
PyHIP Coverage
2 parents ca0aed0 + a37e302 commit ce4dc1e

File tree

1 file changed

+5
-55
lines changed

1 file changed

+5
-55
lines changed

kernel_tuner/backends/hip.py

Lines changed: 5 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -9,29 +9,6 @@
99
from kernel_tuner.backends.backend import GPUBackend
1010
from kernel_tuner.observers.hip import HipRuntimeObserver
1111

12-
_libhip = None
13-
_hip_platform_name = ''
14-
15-
# Try to find amd hip library, if not found, fallback to nvhip library
16-
if 'linux' in sys.platform:
17-
try:
18-
_libhip_libname = 'libamdhip64.so'
19-
_libhip = ctypes.cdll.LoadLibrary(_libhip_libname)
20-
_hip_platform_name = 'amd'
21-
except:
22-
try:
23-
_libhip_libname = 'libnvhip64.so'
24-
_libhip = ctypes.cdll.LoadLibrary(_libhip_libname)
25-
_hip_platform_name = 'nvidia'
26-
except:
27-
raise RuntimeError(
28-
'cant find libamdhip64.so or libnvhip64.so. make sure LD_LIBRARY_PATH is set')
29-
30-
else:
31-
# Currently we do not support windows
32-
raise RuntimeError('Only linux is supported')
33-
34-
3512
# embedded in try block to be able to generate documentation
3613
# and run tests without pyhip installed
3714
try:
@@ -55,14 +32,6 @@
5532
"float64": ctypes.c_double,
5633
}
5734

58-
# define arguments and return value types of HIP functions
59-
_libhip.hipEventQuery.restype = ctypes.c_int
60-
_libhip.hipEventQuery.argtypes = [ctypes.c_void_p]
61-
_libhip.hipModuleGetGlobal.restype = ctypes.c_int
62-
_libhip.hipModuleGetGlobal.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_size_t), ctypes.c_void_p, ctypes.c_char_p]
63-
_libhip.hipMemset.restype = ctypes.c_int
64-
_libhip.hipMemset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
65-
6635
hipSuccess = 0
6736

6837
class HipFunctions(GPUBackend):
@@ -210,11 +179,7 @@ def kernel_finished(self):
210179
logging.debug("HipFunction kernel_finished called")
211180

212181
# Query the status of the event
213-
status = _libhip.hipEventQuery(self.end)
214-
if status == hipSuccess:
215-
return True
216-
else:
217-
return False
182+
return hip.hipEventQuery(self.end)
218183

219184
def synchronize(self):
220185
"""Halts execution until device has finished its tasks"""
@@ -268,12 +233,7 @@ def memset(self, allocation, value, size):
268233
"""
269234
logging.debug("HipFunction memset called")
270235

271-
# Format arguments to correct type, set the memory and
272-
# check return value of memset (as done in PyHIP with hipCheckStatus)
273-
ctypes_value = ctypes.c_int(value)
274-
ctypes_size = ctypes.c_size_t(size)
275-
status = _libhip.hipMemset(allocation, ctypes_value, ctypes_size)
276-
hip.hipCheckStatus(status)
236+
hip.hipMemset(allocation, value, size)
277237

278238
def memcpy_dtoh(self, dest, src):
279239
"""perform a device to host memory copy
@@ -321,23 +281,13 @@ def copy_constant_memory_args(self, cmem_args):
321281

322282
# Iterate over dictionary
323283
for k, v in cmem_args.items():
324-
# Format arguments, call hipModuleGetGlobal,
325-
# and check return status (as done in PyHIP with hipCheckStatus)
326-
symbol_string = ctypes.c_char_p(k.encode('utf-8'))
327-
symbol = ctypes.c_void_p()
328-
symbol_ptr = ctypes.POINTER(ctypes.c_void_p)(symbol)
329-
size_kernel = ctypes.c_size_t(0)
330-
331-
# Get constant memory symbol and check return value of hipModuleGetGlobal
332-
# (as done in PyHIP with hipCheckStatus)
333-
size_kernel_ptr = ctypes.POINTER(ctypes.c_size_t)(size_kernel)
334-
status = _libhip.hipModuleGetGlobal(symbol_ptr, size_kernel_ptr, self.current_module, symbol_string)
335-
hip.hipCheckStatus(status)
284+
#Get symbol pointer
285+
symbol_ptr, _ = hip.hipModuleGetGlobal(self.current_module, k)
336286

337287
#Format arguments and perform memory copy
338288
dtype_str = str(v.dtype)
339289
v_c = v.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str]))
340-
hip.hipMemcpy_htod(symbol_ptr.contents, v_c, v.nbytes)
290+
hip.hipMemcpy_htod(symbol_ptr, v_c, v.nbytes)
341291

342292
def copy_shared_memory_args(self, smem_args):
343293
"""add shared memory arguments to the kernel"""

0 commit comments

Comments
 (0)