Skip to content

Commit bedce91

Browse files
authored
Rewrite Memory Metadata Tagging Pass
Differential Revision: D79116560 Pull Request resolved: #12927
1 parent 02d0c7f commit bedce91

File tree

8 files changed

+1422
-716
lines changed

8 files changed

+1422
-716
lines changed

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
3535

3636
# Mark that this node is going to be represented as a TensorRef type in the
3737
# Vulkan compute graph. This annotation is used in later graph passes.
38-
node.meta["vkdg_tensorref"] = True
38+
node.meta["etvk_tensorref"] = True
3939

4040
# Get the list of node users that do not handle their own prepacking
4141
nodes_to_replace_input = []

backends/vulkan/_passes/remove_local_scalar_dense_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
5252

5353
for user in node.users:
5454
if node_is_local_scalar_dense_chain(user):
55-
node.meta["vkdg_is_scalar_tensor"] = True
55+
node.meta["etvk_is_scalar_tensor"] = True
5656

5757

5858
def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
@@ -74,7 +74,7 @@ def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node)
7474
if replace_node.args[0].meta["val"].numel() == 1:
7575
replace_node = replace_node.args[0]
7676
assert isinstance(replace_node, torch.fx.Node)
77-
assert replace_node.meta.get("vkdg_is_scalar_tensor", True)
77+
assert replace_node.meta.get("etvk_is_scalar_tensor", True)
7878

7979
with graph.inserting_after(node):
8080
node.replace_all_uses_with(replace_node)

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 371 additions & 236 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)