Skip to content

Commit 751511c

Browse files
committed
Fix issue #281
1 parent 10499af commit 751511c

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

kernel_tuner/backends/hip.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
"bool": ctypes.c_bool,
2020
"int8": ctypes.c_int8,
2121
"int16": ctypes.c_int16,
22-
"float16": ctypes.c_int16,
2322
"int32": ctypes.c_int32,
2423
"int64": ctypes.c_int64,
2524
"uint8": ctypes.c_uint8,
@@ -120,25 +119,29 @@ def ready_argument_list(self, arguments):
120119

121120
# Handle numpy arrays
122121
if isinstance(arg, np.ndarray):
123-
if dtype_str in dtype_map.keys():
124-
# Allocate device memory
125-
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
122+
# Allocate device memory
123+
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
126124

127-
# Copy data to device using hipMemcpy
128-
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
125+
# Copy data to device using hipMemcpy
126+
hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
129127

130-
prepared_args.append(device_ptr)
131-
else:
132-
raise TypeError(f"Unknown dtype {dtype_str} for ndarray")
128+
prepared_args.append(device_ptr)
133129

134130
# Handle numpy scalar types
135131
elif isinstance(arg, np.generic):
136132
# Convert numpy scalar to corresponding ctypes
137-
ctype_arg = dtype_map[dtype_str](arg)
138-
prepared_args.append(ctype_arg)
133+
if dtype_str in dtype_map:
134+
ctype_arg = dtype_map[dtype_str](arg)
135+
prepared_args.append(ctype_arg)
136+
# 16-bit float is not supported, view it as uint16
137+
elif dtype_str in ("float16", "bfloat16"):
138+
ctype_arg = ctypes.c_uint16(arg.view(np.uint16))
139+
prepared_args.append(ctype_arg)
140+
else:
141+
raise ValueError(f"Invalid argument type {dtype_str}: {arg}")
139142

140143
else:
141-
raise ValueError(f"Invalid argument type {type(arg)}, {arg}")
144+
raise ValueError(f"Invalid argument type {type(arg)}: {arg}")
142145

143146
return prepared_args
144147

0 commit comments

Comments
 (0)