Skip to content

Commit 9f6a8fb

Browse files
Merge pull request #301 from KernelTuner/fix-issue-281
Add support for 16-bit floats in HIP backend
2 parents 70816dc + 751511c commit 9f6a8fb

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

kernel_tuner/backends/hip.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
"bool": ctypes.c_bool,
2020
"int8": ctypes.c_int8,
2121
"int16": ctypes.c_int16,
22-
"float16": ctypes.c_int16,
2322
"int32": ctypes.c_int32,
2423
"int64": ctypes.c_int64,
2524
"uint8": ctypes.c_uint8,
@@ -122,25 +121,29 @@ def ready_argument_list(self, arguments):
122121

123122
# Handle numpy arrays
124123
if isinstance(arg, np.ndarray):
125-
if dtype_str in dtype_map.keys():
126-
# Allocate device memory
127-
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
124+
# Allocate device memory
125+
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
128126

129-
# Copy data to device using hipMemcpy
130-
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
127+
# Copy data to device using hipMemcpy
128+
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
131129

132-
prepared_args.append(device_ptr)
133-
else:
134-
raise TypeError(f"Unknown dtype {dtype_str} for ndarray")
130+
prepared_args.append(device_ptr)
135131

136132
# Handle numpy scalar types
137133
elif isinstance(arg, np.generic):
138134
# Convert numpy scalar to corresponding ctypes
139-
ctype_arg = dtype_map[dtype_str](arg)
140-
prepared_args.append(ctype_arg)
135+
if dtype_str in dtype_map:
136+
ctype_arg = dtype_map[dtype_str](arg)
137+
prepared_args.append(ctype_arg)
138+
# 16-bit float is not supported, view it as uint16
139+
elif dtype_str in ("float16", "bfloat16"):
140+
ctype_arg = ctypes.c_uint16(arg.view(np.uint16))
141+
prepared_args.append(ctype_arg)
142+
else:
143+
raise ValueError(f"Invalid argument type {dtype_str}: {arg}")
141144

142145
else:
143-
raise ValueError(f"Invalid argument type {type(arg)}, {arg}")
146+
raise ValueError(f"Invalid argument type {type(arg)}: {arg}")
144147

145148
return prepared_args
146149

0 commit comments

Comments
 (0)