Skip to content

Commit 3b23603

Browse files
authored
[ET-VK][ez] Allow partitioner to filter by source nn.Module type name (#13693)
Add two new optional arguments to the vulkan partitioner: `nn_module_blocklist`, `nn_module_allowlist`. This allows the partitioner to filter (or only allow) nodes that originate from a specified nn.Module type name. This feature will be useful for debugging as well as provide a more fine grained control over which parts of the model should be executed via Vulkan. Differential Revision: [D80962718](https://our.internmc.facebook.com/intern/diff/D80962718/) [ghstack-poisoned]
1 parent 9777fb3 commit 3b23603

File tree

3 files changed

+69
-4
lines changed

3 files changed

+69
-4
lines changed

backends/vulkan/partitioner/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ runtime.python_library(
2222
"//executorch/exir/backend:utils",
2323
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
2424
],
25+
typing = True,
2526
)

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(
6161
operator_blocklist: Optional[Set[OpKey]] = None,
6262
operator_allowlist: Optional[Set[OpKey]] = None,
6363
fusable_subgraphs: Optional[List[InternalMatch]] = None,
64+
nn_module_blocklist: Optional[Set[str]] = None,
65+
nn_module_allowlist: Optional[Set[str]] = None,
6466
) -> None:
6567
super().__init__()
6668
self.texture_limits: utils.ImageExtents = texture_limits
@@ -78,6 +80,9 @@ def __init__(
7880
for match in self.fusable_subgraphs:
7981
self.fusable_nodes.update(match.nodes_map.values())
8082

83+
self.nn_module_blocklist = nn_module_blocklist
84+
self.nn_module_allowlist = nn_module_allowlist
85+
8186
def op_node_is_compatible( # noqa: C901: Function is too complex
8287
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
8388
) -> Tuple[bool, str]:
@@ -213,10 +218,26 @@ def is_node_supported(
213218
r = self._is_node_supported(node)
214219
return r
215220

216-
def _is_node_supported(self, node: torch.fx.Node) -> bool:
217-
# Check if this node is part of a fusable subgraph
218-
if node.op == "call_function" and node in self.fusable_nodes:
219-
return True
221+
def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901
222+
if node.op == "call_function":
223+
# Apply nn module allowlist and blocklist
224+
if self.nn_module_allowlist is not None:
225+
if not utils.node_comes_from_any_nn_module_in_set(
226+
node, self.nn_module_allowlist
227+
):
228+
self.log_skip(node, "source nn.Module is not in allowlist")
229+
return False
230+
231+
if self.nn_module_blocklist is not None:
232+
if utils.node_comes_from_any_nn_module_in_set(
233+
node, self.nn_module_blocklist
234+
):
235+
self.log_skip(node, "source nn.Module is in blocklist")
236+
return False
237+
238+
# Check if this node is part of a fusable subgraph
239+
if node in self.fusable_nodes:
240+
return True
220241

221242
target = node.target
222243
if node.target == torch.ops.higher_order.auto_functionalized:
@@ -311,6 +332,8 @@ def __init__(
311332
compile_options: Optional[Dict[str, Any]] = None,
312333
operator_blocklist: Optional[List[OpKey]] = None,
313334
operator_allowlist: Optional[List[OpKey]] = None,
335+
nn_module_blocklist: Optional[List[str]] = None,
336+
nn_module_allowlist: Optional[List[str]] = None,
314337
) -> None:
315338
self.options: Dict[str, Any] = {}
316339
if compile_options is not None:
@@ -331,6 +354,20 @@ def __init__(
331354
assert self.operator_allowlist is not None
332355
self.operator_allowlist.add(entry)
333356

357+
self.nn_module_blocklist: Optional[Set[str]] = None
358+
if nn_module_blocklist is not None:
359+
self.nn_module_blocklist = set()
360+
for entry in nn_module_blocklist or []:
361+
assert self.nn_module_blocklist is not None
362+
self.nn_module_blocklist.add(entry)
363+
364+
self.nn_module_allowlist: Optional[Set[str]] = None
365+
if nn_module_allowlist is not None:
366+
self.nn_module_allowlist = set()
367+
for entry in nn_module_allowlist:
368+
assert self.nn_module_allowlist is not None
369+
self.nn_module_allowlist.add(entry)
370+
334371
def ops_to_not_decompose(
335372
self, ep: ExportedProgram
336373
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -362,6 +399,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
362399
operator_blocklist=self.operator_blocklist,
363400
operator_allowlist=self.operator_allowlist,
364401
fusable_subgraphs=fusable_subgraphs,
402+
nn_module_blocklist=self.nn_module_blocklist,
403+
nn_module_allowlist=self.nn_module_allowlist,
365404
),
366405
allows_single_node_partition=True,
367406
)

backends/vulkan/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,31 @@ def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
250250
return primary_arg_idx
251251

252252

253+
def node_comes_from_any_nn_module_in_set(
254+
node,
255+
nn_module_typenames: Set[str],
256+
) -> bool:
257+
if isinstance(node, (list, tuple)):
258+
return all(
259+
node_comes_from_any_nn_module_in_set(n, nn_module_typenames) for n in node
260+
)
261+
262+
if not isinstance(node, torch.fx.Node):
263+
return False
264+
265+
nn_module_stack = node.meta.get("nn_module_stack", None)
266+
if nn_module_stack is None:
267+
return False
268+
269+
for _, packed in nn_module_stack.items():
270+
_, typename = packed
271+
for partial_name in nn_module_typenames:
272+
if partial_name in typename:
273+
return True
274+
275+
return False
276+
277+
253278
##
254279
## Memory Layout, Storage Type Determination
255280
##

0 commit comments

Comments
 (0)