Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
return vkapi::kChar;
case vkgraph::VkDataType::INT32:
return vkapi::kInt;
case vkgraph::VkDataType::INT64:
return vkapi::kLong;
case vkgraph::VkDataType::FLOAT16:
return vkapi::kHalf;
case vkgraph::VkDataType::FLOAT32:
return vkapi::kFloat;
case vkgraph::VkDataType::FLOAT64:
return vkapi::kDouble;
}
}

Expand Down
6 changes: 4 additions & 2 deletions backends/vulkan/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ enum VkDataType : byte {
UINT8 = 1,
INT8 = 2,
INT32 = 3,
FLOAT16 = 4,
FLOAT32 = 5,
INT64 = 4,
FLOAT16 = 5,
FLOAT32 = 6,
FLOAT64 = 7,
}

// Describes what kind of GPU resource should be used to represent a tensor. The
Expand Down
20 changes: 16 additions & 4 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ def __init__(
self,
program: ExportedProgram,
delegate_mapping_builder: DelegateMappingBuilder,
downcast_64_bit: bool = False,
) -> None:
self.program = program
self.delegate_mapping_builder = delegate_mapping_builder
self.downcast_64_bit = downcast_64_bit
self.chain = []
self.values = []
self.input_ids = []
Expand All @@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
return vk_graph_schema.VkDataType.INT8
elif torch_dtype == torch.int32:
return vk_graph_schema.VkDataType.INT32
elif torch_dtype == torch.int64:
return vk_graph_schema.VkDataType.INT64
elif torch_dtype == torch.float16:
return vk_graph_schema.VkDataType.FLOAT16
elif torch_dtype == torch.float32:
return vk_graph_schema.VkDataType.FLOAT32
# Narrowing conversion for index tensor produced by max_poolNd_with_indices.
elif torch_dtype == torch.int64:
return vk_graph_schema.VkDataType.INT32
elif torch_dtype == torch.float64:
return vk_graph_schema.VkDataType.FLOAT64
else:
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

Expand Down Expand Up @@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
# pyre-ignore[16]
memory_layout = spec.vk_memory_layout

# Apply downcast logic before getting VK datatype
effective_dtype = spec.dtype
if self.downcast_64_bit and spec.dtype == torch.float64:
effective_dtype = torch.float32
elif self.downcast_64_bit and spec.dtype == torch.int64:
effective_dtype = torch.int32

datatype = self.get_vk_datatype(effective_dtype)

new_id = len(self.values)
self.values.append(
vk_graph_schema.VkValue(
value=vk_graph_schema.VkTensor(
datatype=self.get_vk_datatype(spec.dtype),
datatype=datatype,
dims=spec.shape,
constant_id=constant_id,
mem_obj_id=mem_obj_id,
Expand Down
6 changes: 4 additions & 2 deletions backends/vulkan/serialization/vulkan_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ class VkDataType(IntEnum):
UINT8 = 1
INT8 = 2
INT32 = 3
FLOAT16 = 4
FLOAT32 = 5
INT64 = 4
FLOAT16 = 5
FLOAT32 = 6
FLOAT64 = 7


class VkStorageType(IntEnum):
Expand Down
9 changes: 7 additions & 2 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
# pyre-ignore
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
for p in passes:

if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
new_gm = program.graph_module
# This is a workaround to allow the memory planning pass to work without
Expand Down Expand Up @@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
if spec.key == "skip_tag_memory_metadata":
options[spec.key] = bool.from_bytes(spec.value, byteorder="little")

if spec.key == "downcast_64_bit":
options[spec.key] = bool.from_bytes(spec.value, byteorder="little")

# Unhandled options are ignored

return options
Expand Down Expand Up @@ -142,6 +144,7 @@ def preprocess( # noqa: C901
default_memory_layout = compile_options.get(
"memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
)
downcast_64_bit = compile_options.get("downcast_64_bit", False)

program = unsafe_remove_auto_functionalized_pass(program)

Expand Down Expand Up @@ -213,7 +216,9 @@ def preprocess( # noqa: C901
)

graph_builder = VkGraphBuilder(
program, DelegateMappingBuilder(generated_identifiers=True)
program,
DelegateMappingBuilder(generated_identifiers=True),
downcast_64_bit=downcast_64_bit,
)
vk_graph = graph_builder.build_graph()

Expand Down
Loading