3232
3333hipSuccess = 0
3434
35+
3536def hip_check (call_result ):
37+ """helper function to check return values of hip calls"""
3638 err = call_result [0 ]
3739 result = call_result [1 :]
3840 if len (result ) == 1 :
@@ -41,6 +43,7 @@ def hip_check(call_result):
4143 raise RuntimeError (str (err ))
4244 return result
4345
46+
4447class HipFunctions (GPUBackend ):
4548 """Class that groups the HIP functions and maintains state about the device."""
4649
@@ -59,7 +62,9 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
5962 :type iterations: int
6063 """
6164 if not hip or not hiprtc :
62- raise ImportError ("Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python." )
65+ raise ImportError (
66+ "Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python."
67+ )
6368
6469 # embedded in try block to be able to generate documentation
6570 # and run tests without HIP Python installed
@@ -69,7 +74,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
6974 props = hip .hipDeviceProp_t ()
7075 hip_check (hip .hipGetDeviceProperties (props , device ))
7176
72- self .name = props .name .decode (' utf-8' )
77+ self .name = props .name .decode (" utf-8" )
7378 self .max_threads = props .maxThreadsPerBlock
7479 self .device = device
7580 self .compiler_options = compiler_options or []
@@ -81,7 +86,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
8186 env ["compiler_options" ] = compiler_options
8287 self .env = env
8388
84- # Create stream and events
89+ # Create stream and events
8590 self .stream = hip_check (hip .hipStreamCreate ())
8691 self .start = hip_check (hip .hipEventCreate ())
8792 self .end = hip_check (hip .hipEventCreate ())
@@ -108,40 +113,34 @@ def ready_argument_list(self, arguments):
108113 """
109114 logging .debug ("HipFunction ready_argument_list called" )
110115 prepared_args = []
111-
116+
112117 for arg in arguments :
113118 dtype_str = str (arg .dtype )
114-
119+
115120 # Handle numpy arrays
116121 if isinstance (arg , np .ndarray ):
117122 if dtype_str in dtype_map .keys ():
118123 # Allocate device memory
119124 device_ptr = hip_check (hip .hipMalloc (arg .nbytes ))
120-
125+
121126 # Copy data to device using hipMemcpy
122- hip_check (hip .hipMemcpy (
123- device_ptr ,
124- arg ,
125- arg .nbytes ,
126- hip .hipMemcpyKind .hipMemcpyHostToDevice
127- ))
128-
127+ hip_check (hip .hipMemcpy (device_ptr , arg , arg .nbytes , hip .hipMemcpyKind .hipMemcpyHostToDevice ))
128+
129129 prepared_args .append (device_ptr )
130130 else :
131131 raise TypeError (f"Unknown dtype { dtype_str } for ndarray" )
132-
132+
133133 # Handle numpy scalar types
134134 elif isinstance (arg , np .generic ):
135135 # Convert numpy scalar to corresponding ctypes
136136 ctype_arg = dtype_map [dtype_str ](arg )
137137 prepared_args .append (ctype_arg )
138-
138+
139139 else :
140140 raise ValueError (f"Invalid argument type { type (arg )} , { arg } " )
141141
142142 return prepared_args
143143
144-
145144 def compile (self , kernel_instance ):
146145 """Call the HIP compiler to compile the kernel, return the function.
147146
@@ -159,28 +158,22 @@ def compile(self, kernel_instance):
159158 kernel_name = kernel_instance .name
160159 if 'extern "C"' not in kernel_string :
161160 kernel_string = 'extern "C" {\n ' + kernel_string + "\n }"
162-
161+
163162 # Create program
164- prog = hip_check (hiprtc .hiprtcCreateProgram (
165- kernel_string .encode (),
166- kernel_name .encode (),
167- 0 ,
168- [],
169- []
170- ))
163+ prog = hip_check (hiprtc .hiprtcCreateProgram (kernel_string .encode (), kernel_name .encode (), 0 , [], []))
171164
172165 try :
173166 # Get device properties
174167 props = hip .hipDeviceProp_t ()
175168 hip_check (hip .hipGetDeviceProperties (props , 0 ))
176-
169+
177170 # Setup compilation options
178171 arch = props .gcnArchName
179172 cflags = [b"--offload-arch=" + arch ]
180173 cflags .extend ([opt .encode () if isinstance (opt , str ) else opt for opt in self .compiler_options ])
181174
182175 # Compile program
183- err , = hiprtc .hiprtcCompileProgram (prog , len (cflags ), cflags )
176+ ( err ,) = hiprtc .hiprtcCompileProgram (prog , len (cflags ), cflags )
184177 if err != hiprtc .hiprtcResult .HIPRTC_SUCCESS :
185178 # Get compilation log if there's an error
186179 log_size = hip_check (hiprtc .hiprtcGetProgramLogSize (prog ))
@@ -208,19 +201,19 @@ def compile(self, kernel_instance):
208201 def start_event (self ):
209202 """Records the event that marks the start of a measurement."""
210203 logging .debug ("HipFunction start_event called" )
211-
204+
212205 hip_check (hip .hipEventRecord (self .start , self .stream ))
213206
214207 def stop_event (self ):
215208 """Records the event that marks the end of a measurement."""
216209 logging .debug ("HipFunction stop_event called" )
217-
210+
218211 hip_check (hip .hipEventRecord (self .end , self .stream ))
219212
220213 def kernel_finished (self ):
221214 """Returns True if the kernel has finished, False otherwise."""
222215 logging .debug ("HipFunction kernel_finished called" )
223-
216+
224217 # ROCm HIP returns (hipError_t, bool) for hipEventQuery
225218 status = hip .hipEventQuery (self .end )
226219 if status [0 ] == hip .hipError_t .hipSuccess :
@@ -233,7 +226,7 @@ def kernel_finished(self):
233226 def synchronize (self ):
234227 """Halts execution until device has finished its tasks."""
235228 logging .debug ("HipFunction synchronize called" )
236-
229+
237230 hip_check (hip .hipDeviceSynchronize ())
238231
239232 def run_kernel (self , func , gpu_args , threads , grid , stream = None ):
@@ -242,7 +235,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
242235 :param func: A HIP kernel compiled for this specific kernel configuration
243236 :type func: hipFunction_t
244237
245- :param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
238+ :param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
246239 objects or ctypes values
247240 :type gpu_args: list
248241
@@ -272,7 +265,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
272265 sharedMemBytes = self .smem_size ,
273266 stream = stream ,
274267 kernelParams = None ,
275- extra = tuple (gpu_args )
268+ extra = tuple (gpu_args ),
276269 )
277270 )
278271
@@ -303,12 +296,7 @@ def memcpy_dtoh(self, dest, src):
303296 """
304297 logging .debug ("HipFunction memcpy_dtoh called" )
305298
306- hip_check (hip .hipMemcpy (
307- dest ,
308- src ,
309- dest .nbytes ,
310- hip .hipMemcpyKind .hipMemcpyDeviceToHost
311- ))
299+ hip_check (hip .hipMemcpy (dest , src , dest .nbytes , hip .hipMemcpyKind .hipMemcpyDeviceToHost ))
312300
313301 def memcpy_htod (self , dest , src ):
314302 """Perform a host to device memory copy.
@@ -321,12 +309,7 @@ def memcpy_htod(self, dest, src):
321309 """
322310 logging .debug ("HipFunction memcpy_htod called" )
323311
324- hip_check (hip .hipMemcpy (
325- dest ,
326- src ,
327- src .nbytes ,
328- hip .hipMemcpyKind .hipMemcpyHostToDevice
329- ))
312+ hip_check (hip .hipMemcpy (dest , src , src .nbytes , hip .hipMemcpyKind .hipMemcpyHostToDevice ))
330313
331314 def copy_constant_memory_args (self , cmem_args ):
332315 """Adds constant memory arguments to the most recently compiled module.
@@ -343,18 +326,10 @@ def copy_constant_memory_args(self, cmem_args):
343326 # Iterate over dictionary
344327 for symbol_name , data in cmem_args .items ():
345328 # Get symbol pointer and size using hipModuleGetGlobal
346- dptr , _ = hip_check (hip .hipModuleGetGlobal (
347- self .current_module ,
348- symbol_name .encode ()
349- ))
329+ dptr , _ = hip_check (hip .hipModuleGetGlobal (self .current_module , symbol_name .encode ()))
350330
351331 # Copy data to the global memory location
352- hip_check (hip .hipMemcpy (
353- dptr ,
354- data ,
355- data .nbytes ,
356- hip .hipMemcpyKind .hipMemcpyHostToDevice
357- ))
332+ hip_check (hip .hipMemcpy (dptr , data , data .nbytes , hip .hipMemcpyKind .hipMemcpyHostToDevice ))
358333
359334 def copy_shared_memory_args (self , smem_args ):
360335 """Add shared memory arguments to the kernel."""
0 commit comments