|
19 | 19 | "bool": ctypes.c_bool,
|
20 | 20 | "int8": ctypes.c_int8,
|
21 | 21 | "int16": ctypes.c_int16,
|
22 |
| - "float16": ctypes.c_int16, |
23 | 22 | "int32": ctypes.c_int32,
|
24 | 23 | "int64": ctypes.c_int64,
|
25 | 24 | "uint8": ctypes.c_uint8,
|
@@ -122,25 +121,29 @@ def ready_argument_list(self, arguments):
|
122 | 121 |
|
123 | 122 | # Handle numpy arrays
|
124 | 123 | if isinstance(arg, np.ndarray):
|
125 |
| - if dtype_str in dtype_map.keys(): |
126 |
| - # Allocate device memory |
127 |
| - device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) |
| 124 | + # Allocate device memory |
| 125 | + device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) |
128 | 126 |
|
129 |
| - # Copy data to device using hipMemcpy |
130 |
| - hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) |
| 127 | + # Copy data to device using hipMemcpy |
| 128 | + hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) |
131 | 129 |
|
132 |
| - prepared_args.append(device_ptr) |
133 |
| - else: |
134 |
| - raise TypeError(f"Unknown dtype {dtype_str} for ndarray") |
| 130 | + prepared_args.append(device_ptr) |
135 | 131 |
|
136 | 132 | # Handle numpy scalar types
|
137 | 133 | elif isinstance(arg, np.generic):
|
138 | 134 | # Convert numpy scalar to corresponding ctypes
|
139 |
| - ctype_arg = dtype_map[dtype_str](arg) |
140 |
| - prepared_args.append(ctype_arg) |
| 135 | + if dtype_str in dtype_map: |
| 136 | + ctype_arg = dtype_map[dtype_str](arg) |
| 137 | + prepared_args.append(ctype_arg) |
| 138 | + # 16-bit float is not supported, view it as uint16 |
| 139 | + elif dtype_str in ("float16", "bfloat16"): |
| 140 | + ctype_arg = ctypes.c_uint16(arg.view(np.uint16)) |
| 141 | + prepared_args.append(ctype_arg) |
| 142 | + else: |
| 143 | + raise ValueError(f"Invalid argument type {dtype_str}: {arg}") |
141 | 144 |
|
142 | 145 | else:
|
143 |
| - raise ValueError(f"Invalid argument type {type(arg)}, {arg}") |
| 146 | + raise ValueError(f"Invalid argument type {type(arg)}: {arg}") |
144 | 147 |
|
145 | 148 | return prepared_args
|
146 | 149 |
|
|
0 commit comments