Skip to content

Commit 6d586f7

Browse files
committed
[ET-VK] Deserialize VkGraph in ET-VK
Add logic to deserialize a VkGraph blob back python object. This allows us to get a implement debugging / visualization directly on the vulkan-exported program. Still extra works need to be done: From the entire bundle, need to extract the specific vulkan delegate first. Differential Revision: [D66443780](https://our.internmc.facebook.com/intern/diff/D66443780/) [ghstack-poisoned]
1 parent fbcc9a1 commit 6d586f7

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

backends/vulkan/serialization/vulkan_graph_serialize.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
VkBytes,
2020
VkGraph,
2121
)
22-
from executorch.exir._serialize._dataclass import _DataclassEncoder
22+
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
2323

24-
from executorch.exir._serialize._flatbuffer import _flatc_compile
24+
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2525

2626

2727
def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes:
@@ -40,6 +40,23 @@ def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes:
4040
return output_file.read()
4141

4242

43+
def flatbuffer_to_vk_graph(flatbuffers: bytes) -> VkGraph:
44+
with tempfile.TemporaryDirectory() as d:
45+
schema_path = os.path.join(d, "schema.fbs")
46+
with open(schema_path, "wb") as schema_file:
47+
schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs"))
48+
49+
bin_path = os.path.join(d, "schema.bin")
50+
with open(bin_path, "wb") as bin_file:
51+
bin_file.write(flatbuffers)
52+
53+
_flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])
54+
55+
json_path = os.path.join(d, "schema.json")
56+
with open(json_path, "rb") as output_file:
57+
return _json_to_dataclass(json.load(output_file), VkGraph)
58+
59+
4360
@dataclass
4461
class VulkanDelegateHeader:
4562
# Defines the byte region that each component of the header corresponds to

backends/vulkan/test/test_serialization.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,17 @@
1111

1212
import torch
1313

14-
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkGraph
14+
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
15+
IntList,
16+
OperatorCall,
17+
String,
18+
VkGraph,
19+
VkValue,
20+
)
1521

1622
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
23+
convert_to_flatbuffer,
24+
flatbuffer_to_vk_graph,
1725
serialize_vulkan_graph,
1826
VulkanDelegateHeader,
1927
)
@@ -93,3 +101,33 @@ def test_serialize_vulkan_binary(self):
93101

94102
tensor_bytes = bytes(array)
95103
self.assertEqual(constant_data_bytes, tensor_bytes)
104+
105+
def test_serialize_deserialize_vkgraph(self):
106+
in_vk_graph = VkGraph(
107+
version="1",
108+
chain=[
109+
OperatorCall(node_id=1, name="foo", args=[1, 2, 3]),
110+
OperatorCall(node_id=2, name="bar", args=[]),
111+
],
112+
values=[
113+
VkValue(
114+
value=String(
115+
string_val="abc",
116+
),
117+
),
118+
VkValue(
119+
value=IntList(
120+
items=[-1, -4, 2],
121+
),
122+
),
123+
],
124+
input_ids=[],
125+
output_ids=[],
126+
constants=[],
127+
shaders=[],
128+
)
129+
130+
bs = convert_to_flatbuffer(in_vk_graph)
131+
out_vk_graph = flatbuffer_to_vk_graph(bs)
132+
133+
self.assertEqual(in_vk_graph, out_vk_graph)

0 commit comments

Comments
 (0)