1 change: 1 addition & 0 deletions backends/vulkan/partitioner/TARGETS
@@ -22,4 +22,5 @@ runtime.python_library(
"//executorch/exir/backend:utils",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
],
typing = True,
)
47 changes: 43 additions & 4 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -61,6 +61,8 @@ def __init__(
         operator_blocklist: Optional[Set[OpKey]] = None,
         operator_allowlist: Optional[Set[OpKey]] = None,
         fusable_subgraphs: Optional[List[InternalMatch]] = None,
+        nn_module_blocklist: Optional[Set[str]] = None,
+        nn_module_allowlist: Optional[Set[str]] = None,
     ) -> None:
         super().__init__()
         self.texture_limits: utils.ImageExtents = texture_limits
@@ -78,6 +80,9 @@ def __init__(
         for match in self.fusable_subgraphs:
             self.fusable_nodes.update(match.nodes_map.values())
 
+        self.nn_module_blocklist = nn_module_blocklist
+        self.nn_module_allowlist = nn_module_allowlist
+
     def op_node_is_compatible(  # noqa: C901: Function is too complex
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
@@ -213,10 +218,26 @@ def is_node_supported(
         r = self._is_node_supported(node)
         return r
 
-    def _is_node_supported(self, node: torch.fx.Node) -> bool:
-        # Check if this node is part of a fusable subgraph
-        if node.op == "call_function" and node in self.fusable_nodes:
-            return True
+    def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
+        if node.op == "call_function":
+            # Apply nn module allowlist and blocklist
+            if self.nn_module_allowlist is not None:
+                if not utils.node_comes_from_any_nn_module_in_set(
+                    node, self.nn_module_allowlist
+                ):
+                    self.log_skip(node, "source nn.Module is not in allowlist")
+                    return False
+
+            if self.nn_module_blocklist is not None:
+                if utils.node_comes_from_any_nn_module_in_set(
+                    node, self.nn_module_blocklist
+                ):
+                    self.log_skip(node, "source nn.Module is in blocklist")
+                    return False
+
+            # Check if this node is part of a fusable subgraph
+            if node in self.fusable_nodes:
+                return True
 
         target = node.target
         if node.target == torch.ops.higher_order.auto_functionalized:
@@ -311,6 +332,8 @@ def __init__(
         compile_options: Optional[Dict[str, Any]] = None,
         operator_blocklist: Optional[List[OpKey]] = None,
         operator_allowlist: Optional[List[OpKey]] = None,
+        nn_module_blocklist: Optional[List[str]] = None,
+        nn_module_allowlist: Optional[List[str]] = None,
     ) -> None:
         self.options: Dict[str, Any] = {}
         if compile_options is not None:
@@ -331,6 +354,20 @@
                 assert self.operator_allowlist is not None
                 self.operator_allowlist.add(entry)
 
+        self.nn_module_blocklist: Optional[Set[str]] = None
+        if nn_module_blocklist is not None:
+            self.nn_module_blocklist = set()
+            for entry in nn_module_blocklist or []:
+                assert self.nn_module_blocklist is not None
+                self.nn_module_blocklist.add(entry)
+
+        self.nn_module_allowlist: Optional[Set[str]] = None
+        if nn_module_allowlist is not None:
+            self.nn_module_allowlist = set()
+            for entry in nn_module_allowlist:
+                assert self.nn_module_allowlist is not None
+                self.nn_module_allowlist.add(entry)
+
def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -362,6 +399,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 operator_blocklist=self.operator_blocklist,
                 operator_allowlist=self.operator_allowlist,
                 fusable_subgraphs=fusable_subgraphs,
+                nn_module_blocklist=self.nn_module_blocklist,
+                nn_module_allowlist=self.nn_module_allowlist,
             ),
             allows_single_node_partition=True,
         )
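The two new keyword arguments thread from the top-level partitioner's constructor, through the node-support checker, down to the `nn_module_stack` lookup added in `utils.py`. Note the ordering in `_is_node_supported`: the allowlist is consulted first, so when both lists are given, a node must match the allowlist and avoid the blocklist to be lowered. A minimal usage sketch, assuming the import path implied by the file location; the module names in the lists are illustrative, not taken from this PR:

```python
from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
    VulkanPartitioner,
)

partitioner = VulkanPartitioner(
    # Lower only nodes whose nn_module_stack contains a type name matching
    # one of these substrings (illustrative names)...
    nn_module_allowlist=["torch.nn.modules.linear.Linear", "Conv2d"],
    # ...and, of those, skip anything that traces back to an attention module.
    nn_module_blocklist=["Attention"],
)
```

Since matching is by substring, a short name like "Conv2d" catches any module whose fully qualified type name contains it.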
25 changes: 25 additions & 0 deletions backends/vulkan/utils.py
@@ -250,6 +250,31 @@ def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
         return primary_arg_idx
 
 
+def node_comes_from_any_nn_module_in_set(
+    node,
+    nn_module_typenames: Set[str],
+) -> bool:
+    if isinstance(node, (list, tuple)):
+        return all(
+            node_comes_from_any_nn_module_in_set(n, nn_module_typenames) for n in node
+        )
+
+    if not isinstance(node, torch.fx.Node):
+        return False
+
+    nn_module_stack = node.meta.get("nn_module_stack", None)
+    if nn_module_stack is None:
+        return False
+
+    for _, packed in nn_module_stack.items():
+        _, typename = packed
+        for partial_name in nn_module_typenames:
+            if partial_name in typename:
+                return True
+
+    return False
+
+
 ##
 ## Memory Layout, Storage Type Determination
 ##
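For reference, `node_comes_from_any_nn_module_in_set` matches by substring against the type names recorded in a node's `nn_module_stack` metadata, which maps stack keys to `(module_path, type_name)` pairs. A self-contained sketch of that matching rule; the stack contents below are illustrative, not from a real export:

```python
from typing import Dict, Set, Tuple

# Illustrative nn_module_stack metadata of the {key: (module_path, type_name)}
# shape that the helper unpacks.
nn_module_stack: Dict[str, Tuple[str, str]] = {
    "L__self__": ("", "my_model.TransformerBlock"),
    "L__self___attn": ("attn", "my_model.SelfAttention"),
}

def stack_matches(stack: Dict[str, Tuple[str, str]], typenames: Set[str]) -> bool:
    # Mirrors the loop in node_comes_from_any_nn_module_in_set: any substring
    # match against any type name in the stack counts as a hit.
    for _, (_, typename) in stack.items():
        if any(partial in typename for partial in typenames):
            return True
    return False

assert stack_matches(nn_module_stack, {"SelfAttention"})  # exact class name
assert stack_matches(nn_module_stack, {"my_model."})      # package prefix matches too
assert not stack_matches(nn_module_stack, {"Conv2d"})     # absent from the stack
```

The list/tuple branch in the real helper applies the same check recursively and requires all elements to match.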