diff --git a/backends/vulkan/serialization/vulkan_graph_serialize.py b/backends/vulkan/serialization/vulkan_graph_serialize.py index 37785f47521..c97ea69a435 100644 --- a/backends/vulkan/serialization/vulkan_graph_serialize.py +++ b/backends/vulkan/serialization/vulkan_graph_serialize.py @@ -1,6 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # +# pyre-strict +# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,9 +21,9 @@ VkBytes, VkGraph, ) -from executorch.exir._serialize._dataclass import _DataclassEncoder +from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass -from executorch.exir._serialize._flatbuffer import _flatc_compile +from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes: @@ -40,6 +42,25 @@ def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes: return output_file.read() +def flatbuffer_to_vk_graph(flatbuffers: bytes) -> VkGraph: + # Following similar (de)serialization logic on other backends: + # https://github.com/pytorch/executorch/blob/main/backends/qualcomm/serialization/qc_schema_serialize.py#L33 + with tempfile.TemporaryDirectory() as d: + schema_path = os.path.join(d, "schema.fbs") + with open(schema_path, "wb") as schema_file: + schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs")) + + bin_path = os.path.join(d, "schema.bin") + with open(bin_path, "wb") as bin_file: + bin_file.write(flatbuffers) + + _flatc_decompile(d, schema_path, bin_path, ["--raw-binary"]) + + json_path = os.path.join(d, "schema.json") + with open(json_path, "rb") as output_file: + return _json_to_dataclass(json.load(output_file), VkGraph) + + @dataclass class VulkanDelegateHeader: # Defines the byte region that each component of the header corresponds to diff --git a/backends/vulkan/test/test_serialization.py b/backends/vulkan/test/test_serialization.py index eb112d7b12b..c373f5216d2 100644 --- a/backends/vulkan/test/test_serialization.py +++ b/backends/vulkan/test/test_serialization.py @@ -1,6 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # +# pyre-strict +# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,9 +13,17 @@ import torch -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkGraph +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + IntList, + OperatorCall, + String, + VkGraph, + VkValue, +) from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( + convert_to_flatbuffer, + flatbuffer_to_vk_graph, serialize_vulkan_graph, VulkanDelegateHeader, ) @@ -36,7 +46,7 @@ def _generate_random_const_tensors(self, num_tensors: int) -> List[torch.Tensor] return tensors - def test_serialize_vulkan_binary(self): + def test_serialize_vulkan_binary(self) -> None: vk_graph = VkGraph( version="0", chain=[], @@ -93,3 +103,33 @@ def test_serialize_vulkan_binary(self): tensor_bytes = bytes(array) self.assertEqual(constant_data_bytes, tensor_bytes) + + def test_serialize_deserialize_vkgraph(self) -> None: + in_vk_graph = VkGraph( + version="1", + chain=[ + OperatorCall(node_id=1, name="foo", args=[1, 2, 3]), + OperatorCall(node_id=2, name="bar", args=[]), + ], + values=[ + VkValue( + value=String( + string_val="abc", + ), + ), + VkValue( + value=IntList( + items=[-1, -4, 2], + ), + ), + ], + input_ids=[], + output_ids=[], + constants=[], + shaders=[], + ) + + bs = convert_to_flatbuffer(in_vk_graph) + out_vk_graph = flatbuffer_to_vk_graph(bs) + + self.assertEqual(in_vk_graph, out_vk_graph)