We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a0e5031 commit c38c165Copy full SHA for c38c165
sharktank/sharktank/utils/iree.py
@@ -258,7 +258,7 @@ def torch_tensor_to_device_array(
258
tensor: torch.Tensor, device: iree.runtime.HalDevice
259
) -> iree.runtime.DeviceArray:
260
if tensor.dtype in halelementtype_map.keys():
261
- tensor_as_int16 = tensor.view(dtype=torch.int16)
+ tensor_as_int16 = tensor.to(dtype=torch.int16)
262
device_array_as_int16 = iree.runtime.asdevicearray(
263
device, unbox_tensor(tensor_as_int16).to("cpu").detach().numpy()
264
)
0 commit comments