Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 88 additions & 40 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging
from copy import deepcopy
from typing import Set
from typing import Any, Set

import executorch.backends.vulkan.utils as utils

Expand Down Expand Up @@ -190,20 +190,24 @@ def propose_node_layout(
return next(iter(valid_layouts))

def should_annotate(self, node) -> bool:
    """Return True if memory metadata should be set for ``node``.

    ``node`` may be a single ``torch.fx.Node`` or a list/tuple of nodes
    (e.g. an operator argument packing several tensors). A collection is
    annotatable only if every element is itself an annotatable Node.
    Anything else (constants, other arg types) is not annotated.
    """
    if isinstance(node, torch.fx.Node):
        # Only tensor-valued nodes carry storage/layout metadata.
        if not utils.is_tensor_node(node):
            return False

        # Storage type and memory layout for tensorref will be determined at runtime
        # so there's no use in setting those attributes ahead of time.
        if node.meta.get("vkdg_tensorref", False):
            return False

        # Skip annotating output node. The output tensors should be annotated by the
        # time the output node is observed.
        if node.op == "output":
            return False
    elif isinstance(node, (list, tuple)):
        # Recurse: a collection is annotatable iff all elements are.
        return all(
            isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
        )
    else:
        return False

    return True
Expand All @@ -215,6 +219,70 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
# time the prepack node is observed.
return node.target == exir_ops.edge.et_vk.prepack.default

def set_or_transition_arg_node(
    self,
    i: int,
    arg: torch.fx.Node,
    node: torch.fx.Node,
    graph_module: torch.fx.GraphModule,
    dirty: bool,
) -> bool:
    """Align ``arg``'s storage type and memory layout with ``node``'s.

    An ``arg`` with no storage/layout yet simply inherits ``node``'s
    settings. If its existing settings disagree with ``node``'s, a
    transition node is inserted between them. ``dirty`` tells whether a
    transition was already inserted for ``node`` (suppresses the log
    header after the first one). Returns True iff a transition was
    inserted here.
    """
    assert isinstance(arg, torch.fx.Node)

    storage = utils.get_node_storage_type(node)
    assert storage is not None
    layout = utils.get_node_memory_layout(node)
    assert layout is not None

    arg_storage = utils.get_node_storage_type(arg)
    if arg_storage is None:
        # Not decided yet; inherit the consumer node's storage type.
        utils.set_node_spec_attr(arg, "vk_storage_type", storage)
        arg_storage = storage

    arg_layout = utils.get_node_memory_layout(arg)
    if arg_layout is None:
        # Not decided yet; inherit the consumer node's memory layout.
        utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
        arg_layout = layout

    # Already compatible with the consumer — nothing to insert.
    if (arg_storage, arg_layout) == (storage, layout):
        return False

    # Emit the per-node log header only before the first transition.
    if not dirty:
        logger.info(
            f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
        )

    insert_transition_node(graph_module, node, arg, storage, layout)

    logger.info(
        f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
    )

    return True

def set_or_transition_arg(
    self,
    i: int,
    arg: Any,
    node: torch.fx.Node,
    graph_module: torch.fx.GraphModule,
    dirty: bool,
) -> bool:
    """Dispatch ``set_or_transition_arg_node`` over a Node or a collection.

    Single Nodes are forwarded directly with ``dirty``. For a list/tuple,
    each element is processed in order; once any element has required a
    transition, subsequent elements see dirty=True so the log header is
    printed at most once. Other argument types are ignored. Returns True
    iff at least one transition was inserted.
    """
    if isinstance(arg, torch.fx.Node):
        return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty)

    if not isinstance(arg, (list, tuple)):
        # Non-tensor argument (scalar, string, ...) — nothing to do.
        return False

    need_transition = False
    for arg_node in arg:
        inserted = self.set_or_transition_arg_node(
            i, arg_node, node, graph_module, need_transition
        )
        need_transition = need_transition or inserted
    return need_transition

# noqa
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
Expand All @@ -226,36 +294,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

set_memory_metadata(node, storage, layout)

inserting_transitions_for_node = False
need_transition = False
for i, arg in enumerate(node.args):
if not self.should_annotate(arg):
continue

assert isinstance(arg, torch.fx.Node)

arg_storage = utils.get_node_storage_type(arg)
arg_layout = utils.get_node_memory_layout(arg)

if arg_storage is None:
utils.set_node_spec_attr(arg, "vk_storage_type", storage)
arg_storage = storage
if arg_layout is None:
utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
arg_layout = layout

if arg_storage == storage and arg_layout == layout:
continue

if not inserting_transitions_for_node:
inserting_transitions_for_node = True
logger.info(
f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
need_transition = (
self.set_or_transition_arg(
i, arg, node, graph_module, need_transition
)

insert_transition_node(graph_module, node, arg, storage, layout)

logger.info(
f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
or need_transition
)

return PassResult(graph_module, True)
Loading