Skip to content

Commit 749ced7

Browse files
authored
[ET-VK][ez] enabling fp64->fp32 conversion for vulkan compatibility
Differential Revision: D77746137 Pull Request resolved: #12201
1 parent 8666a7b commit 749ced7

File tree

5 files changed

+31
-6
lines changed

5 files changed

+31
-6
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
8383
return vkapi::kChar;
8484
case vkgraph::VkDataType::INT32:
8585
return vkapi::kInt;
86+
case vkgraph::VkDataType::INT64:
87+
return vkapi::kLong;
8688
case vkgraph::VkDataType::FLOAT16:
8789
return vkapi::kHalf;
8890
case vkgraph::VkDataType::FLOAT32:
8991
return vkapi::kFloat;
92+
case vkgraph::VkDataType::FLOAT64:
93+
return vkapi::kDouble;
9094
}
9195
}
9296

backends/vulkan/serialization/schema.fbs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ enum VkDataType : byte {
1818
INT32 = 3,
1919
FLOAT16 = 4,
2020
FLOAT32 = 5,
21+
FLOAT64 = 6,
22+
INT64 = 7,
2123
}
2224

2325
// Describes what kind of GPU resource should be used to represent a tensor. The

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ def __init__(
4545
self,
4646
program: ExportedProgram,
4747
delegate_mapping_builder: DelegateMappingBuilder,
48+
downcast_64_bit: bool = True,
4849
) -> None:
4950
self.program = program
5051
self.delegate_mapping_builder = delegate_mapping_builder
52+
self.downcast_64_bit = downcast_64_bit
5153
self.chain = []
5254
self.values = []
5355
self.input_ids = []
@@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
7274
return vk_graph_schema.VkDataType.INT8
7375
elif torch_dtype == torch.int32:
7476
return vk_graph_schema.VkDataType.INT32
77+
elif torch_dtype == torch.int64:
78+
return vk_graph_schema.VkDataType.INT64
7579
elif torch_dtype == torch.float16:
7680
return vk_graph_schema.VkDataType.FLOAT16
7781
elif torch_dtype == torch.float32:
7882
return vk_graph_schema.VkDataType.FLOAT32
79-
# Narrowing conversion for index tensor produced by max_poolNd_with_indices.
80-
elif torch_dtype == torch.int64:
81-
return vk_graph_schema.VkDataType.INT32
83+
elif torch_dtype == torch.float64:
84+
return vk_graph_schema.VkDataType.FLOAT64
8285
else:
8386
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
8487

@@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
201204
# pyre-ignore[16]
202205
memory_layout = spec.vk_memory_layout
203206

207+
# Apply downcast logic before getting VK datatype
208+
effective_dtype = spec.dtype
209+
if self.downcast_64_bit and spec.dtype == torch.float64:
210+
effective_dtype = torch.float32
211+
elif self.downcast_64_bit and spec.dtype == torch.int64:
212+
effective_dtype = torch.int32
213+
214+
datatype = self.get_vk_datatype(effective_dtype)
215+
204216
new_id = len(self.values)
205217
self.values.append(
206218
vk_graph_schema.VkValue(
207219
value=vk_graph_schema.VkTensor(
208-
datatype=self.get_vk_datatype(spec.dtype),
220+
datatype=datatype,
209221
dims=spec.shape,
210222
constant_id=constant_id,
211223
mem_obj_id=mem_obj_id,

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class VkDataType(IntEnum):
2929
INT32 = 3
3030
FLOAT16 = 4
3131
FLOAT32 = 5
32+
FLOAT64 = 6
33+
INT64 = 7
3234

3335

3436
class VkStorageType(IntEnum):

backends/vulkan/vulkan_preprocess.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
# pyre-ignore
6868
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
6969
for p in passes:
70-
7170
if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
7271
new_gm = program.graph_module
7372
# This is a workaround to allow the memory planning pass to work without
@@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
110109
if spec.key == "skip_tag_memory_metadata":
111110
options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
112111

112+
if spec.key == "downcast_64_bit":
113+
options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
114+
113115
# Unhandled options are ignored
114116

115117
return options
@@ -142,6 +144,7 @@ def preprocess( # noqa: C901
142144
default_memory_layout = compile_options.get(
143145
"memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
144146
)
147+
downcast_64_bit = compile_options.get("downcast_64_bit", True)
145148

146149
program = unsafe_remove_auto_functionalized_pass(program)
147150

@@ -213,7 +216,9 @@ def preprocess( # noqa: C901
213216
)
214217

215218
graph_builder = VkGraphBuilder(
216-
program, DelegateMappingBuilder(generated_identifiers=True)
219+
program,
220+
DelegateMappingBuilder(generated_identifiers=True),
221+
downcast_64_bit=downcast_64_bit,
217222
)
218223
vk_graph = graph_builder.build_graph()
219224

0 commit comments

Comments
 (0)