diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS
index 40e1f36349a..986d872f730 100644
--- a/backends/vulkan/partitioner/TARGETS
+++ b/backends/vulkan/partitioner/TARGETS
@@ -22,4 +22,5 @@ runtime.python_library(
         "//executorch/exir/backend:utils",
         "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
     ],
+    typing = True,
 )
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 04a1a500b64..06db2a58f12 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -61,6 +61,8 @@ def __init__(
         operator_blocklist: Optional[Set[OpKey]] = None,
         operator_allowlist: Optional[Set[OpKey]] = None,
         fusable_subgraphs: Optional[List[InternalMatch]] = None,
+        nn_module_blocklist: Optional[Set[str]] = None,
+        nn_module_allowlist: Optional[Set[str]] = None,
     ) -> None:
         super().__init__()
         self.texture_limits: utils.ImageExtents = texture_limits
@@ -78,6 +80,9 @@ def __init__(
         for match in self.fusable_subgraphs:
             self.fusable_nodes.update(match.nodes_map.values())
 
+        self.nn_module_blocklist = nn_module_blocklist
+        self.nn_module_allowlist = nn_module_allowlist
+
     def op_node_is_compatible(  # noqa: C901: Function is too complex
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
@@ -213,10 +218,26 @@ def is_node_supported(
         r = self._is_node_supported(node)
         return r
 
-    def _is_node_supported(self, node: torch.fx.Node) -> bool:
-        # Check if this node is part of a fusable subgraph
-        if node.op == "call_function" and node in self.fusable_nodes:
-            return True
+    def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
+        if node.op == "call_function":
+            # Apply nn module allowlist and blocklist
+            if self.nn_module_allowlist is not None:
+                if not utils.node_comes_from_any_nn_module_in_set(
+                    node, self.nn_module_allowlist
+                ):
+                    self.log_skip(node, "source nn.Module is not in allowlist")
+                    return False
+
+            if self.nn_module_blocklist is not None:
+                if utils.node_comes_from_any_nn_module_in_set(
+                    node, self.nn_module_blocklist
+                ):
+                    self.log_skip(node, "source nn.Module is in blocklist")
+                    return False
+
+            # Check if this node is part of a fusable subgraph
+            if node in self.fusable_nodes:
+                return True
 
         target = node.target
         if node.target == torch.ops.higher_order.auto_functionalized:
@@ -311,6 +332,8 @@ def __init__(
         compile_options: Optional[Dict[str, Any]] = None,
         operator_blocklist: Optional[List[OpKey]] = None,
         operator_allowlist: Optional[List[OpKey]] = None,
+        nn_module_blocklist: Optional[List[str]] = None,
+        nn_module_allowlist: Optional[List[str]] = None,
     ) -> None:
         self.options: Dict[str, Any] = {}
         if compile_options is not None:
@@ -331,6 +354,20 @@ def __init__(
                 assert self.operator_allowlist is not None
                 self.operator_allowlist.add(entry)
 
+        self.nn_module_blocklist: Optional[Set[str]] = None
+        if nn_module_blocklist is not None:
+            self.nn_module_blocklist = set()
+            for entry in nn_module_blocklist or []:
+                assert self.nn_module_blocklist is not None
+                self.nn_module_blocklist.add(entry)
+
+        self.nn_module_allowlist: Optional[Set[str]] = None
+        if nn_module_allowlist is not None:
+            self.nn_module_allowlist = set()
+            for entry in nn_module_allowlist:
+                assert self.nn_module_allowlist is not None
+                self.nn_module_allowlist.add(entry)
+
     def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -362,6 +399,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 operator_blocklist=self.operator_blocklist,
                 operator_allowlist=self.operator_allowlist,
                 fusable_subgraphs=fusable_subgraphs,
+                nn_module_blocklist=self.nn_module_blocklist,
+                nn_module_allowlist=self.nn_module_allowlist,
             ),
             allows_single_node_partition=True,
         )
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index ee4a8bcc9fc..3b3e27acfbd 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -250,6 +250,31 @@ def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
         return primary_arg_idx
 
 
+def node_comes_from_any_nn_module_in_set(
+    node,
+    nn_module_typenames: Set[str],
+) -> bool:
+    if isinstance(node, (list, tuple)):
+        return all(
+            node_comes_from_any_nn_module_in_set(n, nn_module_typenames) for n in node
+        )
+
+    if not isinstance(node, torch.fx.Node):
+        return False
+
+    nn_module_stack = node.meta.get("nn_module_stack", None)
+    if nn_module_stack is None:
+        return False
+
+    for _, packed in nn_module_stack.items():
+        _, typename = packed
+        for partial_name in nn_module_typenames:
+            if partial_name in typename:
+                return True
+
+    return False
+
+
 ##
 ## Memory Layout, Storage Type Determination
 ##
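A minimal usage sketch of the new filters for reviewers. This assumes the partitioner class is `VulkanPartitioner` and that the import path mirrors the file layout in the diff; the module names passed in are illustrative, not taken from this change.

```python
from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
    VulkanPartitioner,
)

# Hypothetical module name, for illustration only.
partitioner = VulkanPartitioner(
    # Ops originating from any nn.Module whose qualified type name contains
    # one of these substrings will never be delegated to Vulkan.
    nn_module_blocklist=["torch.nn.modules.sparse.Embedding"],
)

# Alternatively, delegate ONLY ops that come from matching modules
# ("MyTransformerBlock" is a hypothetical user-defined module):
# partitioner = VulkanPartitioner(
#     nn_module_allowlist=["MyTransformerBlock"],
# )
```

Per `_is_node_supported`, the allowlist is checked before the blocklist, and both filters apply only to `call_function` nodes. Note also that entries are matched as substrings of the qualified type name, so an entry like `"Linear"` would match any module type whose name contains it, e.g. `NonDynamicallyQuantizableLinear`.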
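The matching in `node_comes_from_any_nn_module_in_set` relies on the `nn_module_stack` metadata that `torch.export` records on each graph node. A standalone sketch of what that metadata looks like; the tuple layout follows what the new helper unpacks, though exact contents can vary across PyTorch versions.

```python
import torch


class Block(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


ep = torch.export.export(Block(), (torch.randn(2, 4),))
for node in ep.graph_module.graph.nodes:
    stack = node.meta.get("nn_module_stack")
    if stack is None:
        continue
    # Each value is a (module_path, qualified_type_name) pair; the helper
    # returns True if any blocklist/allowlist entry is a substring of the
    # type name, e.g. "torch.nn.modules.linear.Linear".
    for _, (path, typename) in stack.items():
        print(f"{node.name}: {path or '<root>'} -> {typename}")
```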