Skip to content

Commit 4d19f2e

Browse files
authored
Merge branch 'master' into searchspace_experiments
2 parents 1513d44 + 486eed9 commit 4d19f2e

File tree

4 files changed

+25
-20
lines changed

4 files changed

+25
-20
lines changed

doc/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,4 @@ tzdata==2025.2 ; python_version >= "3.10" and python_version < "4"
8484
urllib3==2.4.0 ; python_version >= "3.10" and python_version < "4"
8585
wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4"
8686
webencodings==0.5.1 ; python_version >= "3.10" and python_version < "4"
87-
xmltodict==0.14.2 ; python_version >= "3.10" and python_version < "4"
87+
xmltodict==0.14.2 ; python_version >= "3.10" and python_version < "4"

kernel_tuner/backends/hip.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
"bool": ctypes.c_bool,
2020
"int8": ctypes.c_int8,
2121
"int16": ctypes.c_int16,
22-
"float16": ctypes.c_int16,
2322
"int32": ctypes.c_int32,
2423
"int64": ctypes.c_int64,
2524
"uint8": ctypes.c_uint8,
@@ -40,7 +39,9 @@ def hip_check(call_result):
4039
if len(result) == 1:
4140
result = result[0]
4241
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
43-
raise RuntimeError(str(err))
42+
_, error_name = hip.hipGetErrorName(err)
43+
_, error_str = hip.hipGetErrorString(err)
44+
raise RuntimeError(f"{error_name}: {error_str}")
4445
return result
4546

4647

@@ -120,25 +121,29 @@ def ready_argument_list(self, arguments):
120121

121122
# Handle numpy arrays
122123
if isinstance(arg, np.ndarray):
123-
if dtype_str in dtype_map.keys():
124-
# Allocate device memory
125-
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
124+
# Allocate device memory
125+
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
126126

127-
# Copy data to device using hipMemcpy
128-
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
127+
# Copy data to device using hipMemcpy
128+
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
129129

130-
prepared_args.append(device_ptr)
131-
else:
132-
raise TypeError(f"Unknown dtype {dtype_str} for ndarray")
130+
prepared_args.append(device_ptr)
133131

134132
# Handle numpy scalar types
135133
elif isinstance(arg, np.generic):
136134
# Convert numpy scalar to corresponding ctypes
137-
ctype_arg = dtype_map[dtype_str](arg)
138-
prepared_args.append(ctype_arg)
135+
if dtype_str in dtype_map:
136+
ctype_arg = dtype_map[dtype_str](arg)
137+
prepared_args.append(ctype_arg)
138+
# 16-bit float is not supported, view it as uint16
139+
elif dtype_str in ("float16", "bfloat16"):
140+
ctype_arg = ctypes.c_uint16(arg.view(np.uint16))
141+
prepared_args.append(ctype_arg)
142+
else:
143+
raise ValueError(f"Invalid argument type {dtype_str}: {arg}")
139144

140145
else:
141-
raise ValueError(f"Invalid argument type {type(arg)}, {arg}")
146+
raise ValueError(f"Invalid argument type {type(arg)}: {arg}")
142147

143148
return prepared_args
144149

kernel_tuner/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def check_kernel_output(
506506
# run the kernel
507507
check = self.run_kernel(func, gpu_args, instance)
508508
if not check:
509-
# runtime failure occured that should be ignored, skip correctness check
509+
# runtime failure occurred that should be ignored, skip correctness check
510510
return
511511

512512
# retrieve gpu results to host memory
@@ -905,7 +905,7 @@ def split_argument_list(argument_list):
905905
match = re.match(regex, arg, re.S)
906906
if not match:
907907
raise ValueError("error parsing templated kernel argument list")
908-
type_list.append(re.sub(r"\s+", " ", match.group(1).strip(), re.S))
908+
type_list.append(re.sub(r"\s+", " ", match.group(1).strip(), flags=re.S))
909909
name_list.append(match.group(2).strip())
910910
return type_list, name_list
911911

kernel_tuner/util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def check_restriction(restrict, params: dict) -> bool:
289289
return restrict(**selected_params)
290290
# otherwise, raise an error
291291
else:
292-
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
292+
raise ValueError(f"Unknown restriction type {type(restrict)} ({restrict})")
293293

294294

295295
def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
@@ -357,7 +357,7 @@ def f_restrict(p):
357357

358358
elif isinstance(restrict, (InSetConstraint, NotInSetConstraint, SomeInSetConstraint, SomeNotInSetConstraint)):
359359
raise NotImplementedError(
360-
f"Restriction of the type {type(restrict)} is explicitely not supported in backwards compatibility mode, because the behaviour is too complex. Please rewrite this constraint to a function to use it with this algorithm."
360+
f"Restriction of the type {type(restrict)} is explicitly not supported in backwards compatibility mode, because the behaviour is too complex. Please rewrite this constraint to a function to use it with this algorithm."
361361
)
362362
else:
363363
raise TypeError(f"Unrecognized restriction {restrict}")
@@ -602,7 +602,7 @@ def get_total_timings(results, env, overhead_time):
602602
total_verification_time += result["verification_time"]
603603
total_benchmark_time += result["benchmark_time"]
604604

605-
# add the seperate times to the environment dict
605+
# add the separate time values to the environment dict
606606
env["total_framework_time"] = total_framework_time
607607
env["total_strategy_time"] = total_strategy_time
608608
env["total_compile_time"] = total_compile_time
@@ -778,7 +778,7 @@ def prepare_kernel_string(kernel_name, kernel_string, params, grid, threads, blo
778778
v = replace_param_occurrences(v, params)
779779

780780
if not k.isidentifier():
781-
raise ValueError("name is not a valid identifier: {k}")
781+
raise ValueError(f"name is not a valid identifier: {k}")
782782

783783
# Escape newline characters
784784
v = str(v)

0 commit comments

Comments
 (0)