|
19 | 19 | "bool": ctypes.c_bool, |
20 | 20 | "int8": ctypes.c_int8, |
21 | 21 | "int16": ctypes.c_int16, |
22 | | - "float16": ctypes.c_int16, |
23 | 22 | "int32": ctypes.c_int32, |
24 | 23 | "int64": ctypes.c_int64, |
25 | 24 | "uint8": ctypes.c_uint8, |
@@ -122,25 +121,29 @@ def ready_argument_list(self, arguments): |
122 | 121 |
|
123 | 122 | # Handle numpy arrays |
124 | 123 | if isinstance(arg, np.ndarray): |
125 | | - if dtype_str in dtype_map.keys(): |
126 | | - # Allocate device memory |
127 | | - device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) |
| 124 | + # Allocate device memory |
| 125 | + device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) |
128 | 126 |
|
129 | | - # Copy data to device using hipMemcpy |
130 | | - hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) |
| 127 | + # Copy data to device using hipMemcpy |
| 128 | + hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) |
131 | 129 |
|
132 | | - prepared_args.append(device_ptr) |
133 | | - else: |
134 | | - raise TypeError(f"Unknown dtype {dtype_str} for ndarray") |
| 130 | + prepared_args.append(device_ptr) |
135 | 131 |
|
136 | 132 | # Handle numpy scalar types |
137 | 133 | elif isinstance(arg, np.generic): |
138 | 134 | # Convert numpy scalar to corresponding ctypes |
139 | | - ctype_arg = dtype_map[dtype_str](arg) |
140 | | - prepared_args.append(ctype_arg) |
| 135 | + if dtype_str in dtype_map: |
| 136 | + ctype_arg = dtype_map[dtype_str](arg) |
| 137 | + prepared_args.append(ctype_arg) |
| 138 | + # 16-bit float is not supported, view it as uint16 |
| 139 | + elif dtype_str in ("float16", "bfloat16"): |
| 140 | + ctype_arg = ctypes.c_uint16(arg.view(np.uint16)) |
| 141 | + prepared_args.append(ctype_arg) |
| 142 | + else: |
| 143 | + raise ValueError(f"Invalid argument type {dtype_str}: {arg}") |
141 | 144 |
|
142 | 145 | else: |
143 | | - raise ValueError(f"Invalid argument type {type(arg)}, {arg}") |
| 146 | + raise ValueError(f"Invalid argument type {type(arg)}: {arg}") |
144 | 147 |
|
145 | 148 | return prepared_args |
146 | 149 |
|
|
0 commit comments