Skip to content

Commit 5ae62c6

Browse files
committed
changed once again how gpu arguments get passed around --> ctypes structure
1 parent 237f39b commit 5ae62c6

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

kernel_tuner/backends/hip.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def ready_argument_list(self, arguments):
128128
The order should match the argument list on the HIP function.
129129
Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
130130
:type arguments: list(numpy objects)
131-
:returns: List of ctypes arguments to be passed to the HIP function.
132-
:rtype: list of ctypes
131+
:returns: Ctypes structure of arguments to be passed to the HIP function.
132+
:rtype: ctypes structure
133133
"""
134134
logging.debug("HipFunction ready_argument_list called")
135135

@@ -150,9 +150,17 @@ def ready_argument_list(self, arguments):
150150
elif isinstance(arg, np.generic):
151151
data_ctypes = dtype_map[dtype_str](arg)
152152
ctype_args.append(data_ctypes)
153+
154+
# Determine the types of the fields in the structure
155+
field_types = [type(x) for x in ctype_args]
156+
# Define a new ctypes structure with the inferred layout
157+
class ArgListStructure(ctypes.Structure):
158+
_fields_ = [(f'field{i}', t) for i, t in enumerate(field_types)]
159+
def __getitem__(self, key):
160+
return self._fields_[key]
153161

154-
return ctype_args
155-
162+
return ArgListStructure(*ctype_args)
163+
156164

157165
def compile(self, kernel_instance):
158166
"""call the HIP compiler to compile the kernel, return the function
@@ -219,10 +227,10 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
219227
:param func: A PyHIP kernel compiled for this specific kernel configuration
220228
:type func: ctypes pointer
221229
222-
:param gpu_args: A list of arguments to the kernel, order should match the
230+
:param gpu_args: A ctypes structure of arguments to the kernel, order should match the
223231
order in the code. Allowed values are either variables in global memory
224232
or single values passed by value.
225-
:type gpu_args: list of ctypes
233+
:type gpu_args: ctypes structure
226234
227235
:param threads: A tuple listing the number of threads in each dimension of
228236
the thread block
@@ -233,17 +241,10 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
233241
:type grid: tuple(int, int, int)
234242
"""
235243
logging.debug("HipFunction run_kernel called")
244+
236245
if stream is None:
237246
stream = self.stream
238247

239-
# Determine the types of the fields in the structure
240-
field_types = [type(x) for x in gpu_args]
241-
# Define a new ctypes structure with the inferred layout
242-
class ArgListStructure(ctypes.Structure):
243-
_fields_ = [(f'field{i}', t) for i, t in enumerate(field_types)]
244-
245-
gpu_args = ArgListStructure(*gpu_args)
246-
247248
hip.hipModuleLaunchKernel(func,
248249
grid[0], grid[1], grid[2],
249250
threads[0], threads[1], threads[2],

test/test_hip_functions.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,8 @@ def test_ready_argument_list():
2525

2626
dev = kt_hip.HipFunctions(0)
2727
gpu_args = dev.ready_argument_list(arguments)
28-
29-
assert isinstance(gpu_args[0], ctypes.c_void_p)
30-
assert isinstance(gpu_args[1], ctypes.c_int32)
31-
assert isinstance(gpu_args[2], ctypes.c_void_p)
32-
assert isinstance(gpu_args[3], ctypes.c_bool)
28+
29+
assert(gpu_args, ctypes.Structure)
3330

3431
@skip_if_no_pyhip
3532
def test_compile():

0 commit comments

Comments
 (0)