Commit 43f2cc7

morelos committed
[ET-VK][ez] enabling fp64->fp32 conversion for vulkan compatibility

# Context

We need this conversion so that certain operators can handle floating-point values that need to be 64-bit. This predominantly applies to choose_qparams.tensor, which expects a 64-bit output.

# Changes

Simply add an additional conversion from float64 to Vulkan fp32.

Differential Revision: [D77746137](https://our.internmc.facebook.com/intern/diff/D77746137/)

[ghstack-poisoned]
1 parent 734e1f8 commit 43f2cc7

File tree

1 file changed (+4, -0 lines)


backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -79,6 +79,10 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
     # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
     elif torch_dtype == torch.int64:
         return vk_graph_schema.VkDataType.INT32
+    # Narrowing conversion for float64 (double) to float32 for Vulkan compatibility
+    elif torch_dtype == torch.float64:
+        return vk_graph_schema.VkDataType.FLOAT32
+
     else:
         raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
```
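In isolation, the dispatch logic touched by this hunk can be sketched as follows. This is a minimal, self-contained illustration, not the real builder: string dtype names and a two-member `VkDataType` enum stand in for `torch.dtype` and `vk_graph_schema.VkDataType`, and the branches for the other dtypes the real `get_vk_datatype` supports are omitted.

```python
from enum import Enum


# Simplified stand-in for vk_graph_schema.VkDataType (illustration only).
class VkDataType(Enum):
    INT32 = 0
    FLOAT32 = 1


def get_vk_datatype(torch_dtype: str) -> VkDataType:
    # Narrowing conversions: the Vulkan delegate has no 64-bit types in
    # this mapping, so 64-bit dtypes map down to 32-bit counterparts.
    if torch_dtype == "int64":
        return VkDataType.INT32
    # The branch added by this commit: float64 (double) narrows to float32.
    elif torch_dtype == "float64":
        return VkDataType.FLOAT32
    else:
        raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
```

Note that both branches are narrowing conversions: float64 values outside float32's range or precision (as with large int64 indices narrowed to int32) will lose information when lowered this way.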
