Skip to content

Commit f6cc000

Browse files
committed
modified (not tested) memset, copy_constant_memory_args
1 parent e003d1c commit f6cc000

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

kernel_tuner/backends/hip.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,14 @@
6464
"float64": ctypes.c_double,
6565
}
6666

67-
# define arguments and return value ctypes for hipEventQuery
67+
# Declare ctypes return and argument types for the HIP functions that are
# called directly through _libhip, so ctypes converts arguments correctly.
_libhip.hipEventQuery.restype = ctypes.c_int
_libhip.hipEventQuery.argtypes = [ctypes.c_void_p]
# hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes,
#                               hipModule_t hmod, const char* name)
# dptr and bytes are output parameters, hence the pointer types.
_libhip.hipModuleGetGlobal.restype = ctypes.c_int
_libhip.hipModuleGetGlobal.argtypes = [
    ctypes.POINTER(ctypes.c_void_p),
    ctypes.POINTER(ctypes.c_size_t),
    ctypes.c_void_p,
    ctypes.c_char_p,
]
# hipError_t hipMemset(void* dst, int value, size_t sizeBytes)
# (previously this line re-assigned hipModuleGetGlobal.argtypes by mistake)
_libhip.hipMemset.restype = ctypes.c_int
_libhip.hipMemset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
74+
7075

7176
hipSuccess = 0
7277

@@ -108,6 +113,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
108113
self.end = hip.hipEventCreate()
109114

110115
self.smem_size = 0
116+
self.current_module = None
111117

112118
# setup observers
113119
self.observers = observers or []
@@ -125,20 +131,22 @@ def ready_argument_list(self, arguments):
125131
:returns: A ctypes structure that can be passed to the HIP function.
126132
:rtype: ctypes.Structure
127133
"""
128-
129134
logging.debug("HipFunction ready_argument_list called")
135+
130136
ctype_args = []
131137
data_ctypes = None
132138
for arg in arguments:
133139
dtype_str = str(arg.dtype)
140+
# Allocate space on device for array and convert to ctypes
134141
if isinstance(arg, np.ndarray):
135142
if dtype_str in dtype_map.keys():
136143
device_ptr = hip.hipMalloc(arg.nbytes)
137144
data_ctypes = arg.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str]))
138145
hip.hipMemcpy_htod(device_ptr, data_ctypes, arg.nbytes)
139146
ctype_args.append(device_ptr)
140147
else:
141-
raise TypeError("unknown dtype for ndarray")
148+
raise TypeError("unknown dtype for ndarray")
149+
# Convert valid non-array arguments to ctypes
142150
elif isinstance(arg, np.generic):
143151
data_ctypes = dtype_map[dtype_str](arg)
144152
ctype_args.append(data_ctypes)
@@ -180,6 +188,7 @@ def compile(self, kernel_instance):
180188
hiprtc.hiprtcCompileProgram(kernel_ptr, [])
181189
code = hiprtc.hiprtcGetCode(kernel_ptr)
182190
module = hip.hipModuleLoadData(code)
191+
self.current_module = module
183192
kernel = hip.hipModuleGetFunction(module, kernel_name)
184193

185194
return kernel
@@ -242,17 +251,20 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
242251
def memset(self, allocation, value, size):
    """Set the memory in allocation to the value in value.

    :param allocation: A GPU memory allocation unit
    :type allocation: ctypes ptr

    :param value: The value to set the memory to
    :type value: a single 8-bit unsigned int

    :param size: The size of the allocation unit in bytes
    :type size: int

    """
    # Same debug trace as the other backend methods in this file.
    logging.debug("HipFunction memset called")
    # hipMemset(void* dst, int value, size_t sizeBytes): wrap the Python
    # numbers explicitly so they are passed with the expected C types.
    ctypes_value = ctypes.c_int(value)
    ctypes_size = ctypes.c_size_t(size)
    status = _libhip.hipMemset(allocation, ctypes_value, ctypes_size)
    # NOTE(review): hipCheckStatus presumably raises when status is not
    # hipSuccess — confirm against the hip wrapper module.
    hip.hipCheckStatus(status)
256268

257269
def memcpy_dtoh(self, dest, src):
258270
"""perform a device to host memory copy
@@ -281,8 +293,25 @@ def memcpy_htod(self, dest, src):
281293
hip.hipMemcpy_htod(dest, ctypes.byref(src.ctypes), ctypes.sizeof(dtype_map[dtype_str]) * src.size)
282294

283295
def copy_constant_memory_args(self, cmem_args):
    """Copy constant memory arguments into the most recently compiled module.

    For every entry, the device address of the named symbol is looked up in
    ``self.current_module`` and the host data is copied to that address.

    :param cmem_args: A dictionary containing the data to be passed to the
        device constant memory. The format to be used is as follows: A
        string key is used to name the constant memory symbol to which the
        value needs to be copied. Similar to regular arguments, these need
        to be numpy objects, such as numpy.ndarray or numpy.int32, and so on.
    :type cmem_args: dict( string: numpy.ndarray, ... )

    :raises KeyError: if the dtype of a value is not present in ``dtype_map``
    """
    logging.debug("HipFunction copy_constant_memory_args called")
    logging.debug("current module: " + str(self.current_module))

    for k, v in cmem_args.items():
        # hipModuleGetGlobal writes the symbol address and its size through
        # output pointers, so real ctypes instances must be passed by
        # reference (not the ctypes classes themselves). This requires the
        # first two parameters of hipModuleGetGlobal to be treated as
        # pointers (hipDeviceptr_t*, size_t*).
        symbol = ctypes.c_void_p()
        size_kernel = ctypes.c_size_t(0)
        status = _libhip.hipModuleGetGlobal(
            ctypes.byref(symbol),
            ctypes.byref(size_kernel),
            self.current_module,
            str.encode(k),
        )
        hip.hipCheckStatus(status)

        # Numpy scalars have no usable buffer interface via .ctypes and
        # arrays may be non-contiguous; normalise to a contiguous ndarray
        # (dtype is preserved) before taking the host pointer.
        host = np.ascontiguousarray(v)
        dtype_str = str(host.dtype)
        # Same host-pointer construction as ready_argument_list.
        host_ptr = host.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str]))
        hip.hipMemcpy_htod(symbol, host_ptr, host.nbytes)
286315

287316
def copy_shared_memory_args(self, smem_args):
288317
"""This method must implement the dynamic allocation of shared memory on the GPU."""
@@ -292,5 +321,6 @@ def copy_shared_memory_args(self, smem_args):
292321
def copy_texture_memory_args(self, texmem_args):
    """Placeholder for the allocation and copy of texture memory to the GPU.

    At present this only emits a debug log message: ``texmem_args`` is not
    read and nothing is allocated or copied.

    :param texmem_args: Texture memory arguments (currently unused)
    """
    logging.debug("HipFunction copy_texture_memory_args called")
    # TODO: open question from the original author — not supported?
295325

296326
units = {"time": "ms"}

0 commit comments

Comments
 (0)