@@ -128,8 +128,8 @@ def ready_argument_list(self, arguments):
128128 The order should match the argument list on the HIP function.
129129 Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
130130 :type arguments: list(numpy objects)
131- :returns: List of ctypes arguments to be passed to the HIP function.
132- :rtype: list of ctypes
131+ :returns: Ctypes structure of arguments to be passed to the HIP function.
132+ :rtype: ctypes structure
133133 """
134134 logging .debug ("HipFunction ready_argument_list called" )
135135
@@ -150,9 +150,17 @@ def ready_argument_list(self, arguments):
150150 elif isinstance (arg , np .generic ):
151151 data_ctypes = dtype_map [dtype_str ](arg )
152152 ctype_args .append (data_ctypes )
153+
154+ # Determine the types of the fields in the structure
155+ field_types = [type (x ) for x in ctype_args ]
156+ # Define a new ctypes structure with the inferred layout
157+ class ArgListStructure (ctypes .Structure ):
158+ _fields_ = [(f'field{ i } ' , t ) for i , t in enumerate (field_types )]
159+ def __getitem__ (self , key ):
160+ return self ._fields_ [key ]
153161
154- return ctype_args
155-
162+ return ArgListStructure ( * ctype_args )
163+
156164
157165 def compile (self , kernel_instance ):
158166 """call the HIP compiler to compile the kernel, return the function
@@ -219,10 +227,10 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
219227 :param func: A PyHIP kernel compiled for this specific kernel configuration
220228 :type func: ctypes pionter
221229
222- :param gpu_args: A list of arguments to the kernel, order should match the
230+ :param gpu_args: A ctypes structure of arguments to the kernel, order should match the
223231 order in the code. Allowed values are either variables in global memory
224232 or single values passed by value.
225- :type gpu_args: list of ctypes
233+ :type gpu_args: ctypes structure
226234
227235 :param threads: A tuple listing the number of threads in each dimension of
228236 the thread block
@@ -233,17 +241,10 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
233241 :type grid: tuple(int, int, int)
234242 """
235243 logging .debug ("HipFunction run_kernel called" )
244+
236245 if stream is None :
237246 stream = self .stream
238247
239- # Determine the types of the fields in the structure
240- field_types = [type (x ) for x in gpu_args ]
241- # Define a new ctypes structure with the inferred layout
242- class ArgListStructure (ctypes .Structure ):
243- _fields_ = [(f'field{ i } ' , t ) for i , t in enumerate (field_types )]
244-
245- gpu_args = ArgListStructure (* gpu_args )
246-
247248 hip .hipModuleLaunchKernel (func ,
248249 grid [0 ], grid [1 ], grid [2 ],
249250 threads [0 ], threads [1 ], threads [2 ],
0 commit comments