6464 "float64" : ctypes .c_double ,
6565}
6666
67- # define arguments and return value ctypes for hipEventQuery
67+ # define arguments and return value types of HIP functions
6868_libhip .hipEventQuery .restype = ctypes .c_int
6969_libhip .hipEventQuery .argtypes = [ctypes .c_void_p ]
70+ _libhip .hipModuleGetGlobal .restype = ctypes .c_int
71+ _libhip .hipModuleGetGlobal .argtypes = [ctypes .c_void_p , ctypes .c_size_t , ctypes .c_void_p , ctypes .c_char_p ]
72+ _libhip .hipMemset .restype = ctypes .c_int
73+ _libhip .hipModuleGetGlobal .argtypes = [ctypes .c_void_p , ctypes .c_int , ctypes .c_size_t ]
74+
7075
7176hipSuccess = 0
7277
@@ -108,6 +113,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
108113 self .end = hip .hipEventCreate ()
109114
110115 self .smem_size = 0
116+ self .current_module = None
111117
112118 # setup observers
113119 self .observers = observers or []
@@ -125,20 +131,22 @@ def ready_argument_list(self, arguments):
125131 :returns: A ctypes structure that can be passed to the HIP function.
126132 :rtype: ctypes.Structure
127133 """
128-
129134 logging .debug ("HipFunction ready_argument_list called" )
135+
130136 ctype_args = []
131137 data_ctypes = None
132138 for arg in arguments :
133139 dtype_str = str (arg .dtype )
140+ # Allocate space on device for array and convert to ctypes
134141 if isinstance (arg , np .ndarray ):
135142 if dtype_str in dtype_map .keys ():
136143 device_ptr = hip .hipMalloc (arg .nbytes )
137144 data_ctypes = arg .ctypes .data_as (ctypes .POINTER (dtype_map [dtype_str ]))
138145 hip .hipMemcpy_htod (device_ptr , data_ctypes , arg .nbytes )
139146 ctype_args .append (device_ptr )
140147 else :
141- raise TypeError ("unknown dtype for ndarray" )
148+ raise TypeError ("unknown dtype for ndarray" )
149+ # Convert valid non-array arguments to ctypes
142150 elif isinstance (arg , np .generic ):
143151 data_ctypes = dtype_map [dtype_str ](arg )
144152 ctype_args .append (data_ctypes )
@@ -180,6 +188,7 @@ def compile(self, kernel_instance):
180188 hiprtc .hiprtcCompileProgram (kernel_ptr , [])
181189 code = hiprtc .hiprtcGetCode (kernel_ptr )
182190 module = hip .hipModuleLoadData (code )
191+ self .current_module = module
183192 kernel = hip .hipModuleGetFunction (module , kernel_name )
184193
185194 return kernel
@@ -242,17 +251,20 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
242251 def memset (self , allocation , value , size ):
243252 """set the memory in allocation to the value in value
244253
245- :param allocation: An Argument for some memory allocation unit
254+ :param allocation: A GPU memory allocation unit
246255 :type allocation: ctypes ptr
247256
248257 :param value: The value to set the memory to
249258 :type value: a single 8-bit unsigned int
250259
251260 :param size: The size of to the allocation unit in bytes
252261 :type size: int
262+
253263 """
254- logging .debug ("HipFunction memset called" )
255- allocation .contents .value = value # probably wrong, still have to look into this
264+ ctypes_value = ctypes .c_int (value )
265+ ctypes_size = ctypes .c_size_t (size )
266+ status = _libhip .hipMemset (allocation , ctypes_value , ctypes_size )
267+ hip .hipCheckStatus (status )
256268
257269 def memcpy_dtoh (self , dest , src ):
258270 """perform a device to host memory copy
@@ -281,8 +293,25 @@ def memcpy_htod(self, dest, src):
281293 hip .hipMemcpy_htod (dest , ctypes .byref (src .ctypes ), ctypes .sizeof (dtype_map [dtype_str ]) * src .size )
282294
283295 def copy_constant_memory_args (self , cmem_args ):
284- """This method must implement the allocation and copy of constant memory to the GPU."""
296+ """adds constant memory arguments to the most recently compiled module
297+
298+ :param cmem_args: A dictionary containing the data to be passed to the
299+ device constant memory. The format to be used is as follows: A
300+ string key is used to name the constant memory symbol to which the
301+ value needs to be copied. Similar to regular arguments, these need
302+ to be numpy objects, such as numpy.ndarray or numpy.int32, and so on.
303+ :type cmem_args: dict( string: numpy.ndarray, ... )
304+ """
285305 logging .debug ("HipFunction copy_constant_memory_args called" )
306+ logging .debug ("current module: " + str (self .current_module ))
307+
308+ for k , v in cmem_args .items ():
309+ symbol = ctypes .c_void_p
310+ size_kernel = ctypes .c_size_t
311+ status = _libhip .hipModuleGetGlobal (symbol , size_kernel , self .current_module , str .encode (k ))
312+ hip .hipCheckStatus (status )
313+ dtype_str = str (v .dtype )
314+ hip .hipMemcpy_htod (symbol , ctypes .byref (v .ctypes ), ctypes .sizeof (dtype_map [dtype_str ]) * v .size )
286315
287316 def copy_shared_memory_args (self , smem_args ):
288317 """This method must implement the dynamic allocation of shared memory on the GPU."""
@@ -292,5 +321,6 @@ def copy_shared_memory_args(self, smem_args):
292321 def copy_texture_memory_args (self , texmem_args ):
293322 """This method must implement the allocation and copy of texture memory to the GPU."""
294323 logging .debug ("HipFunction copy_texture_memory_args called" )
324+ # NOT SUPPORTED?
295325
296326 units = {"time" : "ms" }
0 commit comments