Skip to content
4 changes: 4 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
# Narrowing conversion for index tensor produced by max_poolNd_with_indices.
elif torch_dtype == torch.int64:
return vk_graph_schema.VkDataType.INT32
# Narrowing conversion for float64 (double) to float32 for Vulkan compatibility
elif torch_dtype == torch.float64:
return vk_graph_schema.VkDataType.FLOAT32

else:
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

Expand Down
Loading