Skip to content

Commit c38c165

Browse files
committed
fix for fp8 kv cache
1 parent a0e5031 commit c38c165

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

sharktank/sharktank/utils/iree.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -258,7 +258,7 @@ def torch_tensor_to_device_array(
258 258       tensor: torch.Tensor, device: iree.runtime.HalDevice
259 259   ) -> iree.runtime.DeviceArray:
260 260       if tensor.dtype in halelementtype_map.keys():
261     -         tensor_as_int16 = tensor.view(dtype=torch.int16)
    261 +         tensor_as_int16 = tensor.to(dtype=torch.int16)
262 262           device_array_as_int16 = iree.runtime.asdevicearray(
263 263               device, unbox_tensor(tensor_as_int16).to("cpu").detach().numpy()
264 264           )

0 commit comments

Comments
 (0)