@@ -45,9 +45,11 @@ def __init__(
4545 self ,
4646 program : ExportedProgram ,
4747 delegate_mapping_builder : DelegateMappingBuilder ,
48+ downcast_64_bit : bool = True ,
4849 ) -> None :
4950 self .program = program
5051 self .delegate_mapping_builder = delegate_mapping_builder
52+ self .downcast_64_bit = downcast_64_bit
5153 self .chain = []
5254 self .values = []
5355 self .input_ids = []
@@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
7274 return vk_graph_schema .VkDataType .INT8
7375 elif torch_dtype == torch .int32 :
7476 return vk_graph_schema .VkDataType .INT32
77+ elif torch_dtype == torch .int64 :
78+ return vk_graph_schema .VkDataType .INT64
7579 elif torch_dtype == torch .float16 :
7680 return vk_graph_schema .VkDataType .FLOAT16
7781 elif torch_dtype == torch .float32 :
7882 return vk_graph_schema .VkDataType .FLOAT32
79- # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
80- elif torch_dtype == torch .int64 :
81- return vk_graph_schema .VkDataType .INT32
83+ elif torch_dtype == torch .float64 :
84+ return vk_graph_schema .VkDataType .FLOAT64
8285 else :
8386 raise AssertionError (f"Invalid dtype for vulkan_preprocess ({ torch_dtype } )" )
8487
@@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
201204 # pyre-ignore[16]
202205 memory_layout = spec .vk_memory_layout
203206
207+ # Apply downcast logic before getting VK datatype
208+ effective_dtype = spec .dtype
209+ if self .downcast_64_bit and spec .dtype == torch .float64 :
210+ effective_dtype = torch .float32
211+ elif self .downcast_64_bit and spec .dtype == torch .int64 :
212+ effective_dtype = torch .int32
213+
214+ datatype = self .get_vk_datatype (effective_dtype )
215+
204216 new_id = len (self .values )
205217 self .values .append (
206218 vk_graph_schema .VkValue (
207219 value = vk_graph_schema .VkTensor (
208- datatype = self . get_vk_datatype ( spec . dtype ) ,
220+ datatype = datatype ,
209221 dims = spec .shape ,
210222 constant_id = constant_id ,
211223 mem_obj_id = mem_obj_id ,
0 commit comments