Skip to content

Commit aa51149

Browse files
SS-JIA and ssjia authored
[ET-VK] Add optional blocklist and allowlist to vulkan partitioner to aid debugging (#13326)
Summary: ## Changes * Add `operator_allowlist` and `operator_blocklist` optional arguments to `VulkanPartitioner` * `operator_blocklist` will prevent operators in the block list to be lowered to Vulkan * `operator_allowlist` will only allow operators in the allow list to be lowered to Vulkan * `operator_allowlist` takes precedence over `operator_blocklist` ## Context When debugging models, it is useful to be able to prevent certain operators from being lowered to Vulkan, or to only allow certain operators from being lowered to Vulkan. This can help isolate which ops are causing model output to be incorrect. Test Plan: ## Test Plan Tested this feature locally while debugging example models. Co-authored-by: ssjia <[email protected]>
1 parent c5db75b commit aa51149

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
import logging
10-
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple
10+
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple
1111

1212
import executorch.backends.vulkan.utils as utils
1313

@@ -17,6 +17,7 @@
1717
get_op_features,
1818
has_impl,
1919
OpFeatures,
20+
OpKey,
2021
vulkan_supported_ops,
2122
)
2223

@@ -55,11 +56,17 @@ def __init__(
5556
texture_limits: utils.ImageExtents,
5657
buffer_limit: int,
5758
require_dynamic_shape: bool = False,
59+
operator_blocklist: Optional[Set[OpKey]] = None,
60+
operator_allowlist: Optional[Set[OpKey]] = None,
5861
) -> None:
5962
super().__init__()
6063
self.texture_limits: utils.ImageExtents = texture_limits
6164
self.buffer_limit = buffer_limit
6265
self.require_dynamic_shapes = require_dynamic_shape
66+
self.operator_blocklist: Set[OpKey] = (
67+
operator_blocklist if operator_blocklist is not None else set()
68+
)
69+
self.operator_allowlist = operator_allowlist
6370

6471
def op_node_is_compatible( # noqa: C901: Function is too complex
6572
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -77,6 +84,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
7784
assert isinstance(first_arg, torch._ops.OpOverload)
7885
target = first_arg.name()
7986

87+
# Operator allow list is only used for torch ops
88+
if (
89+
utils.is_torch_op_node(node)
90+
and (self.operator_allowlist is not None)
91+
and (target not in self.operator_allowlist)
92+
):
93+
return False, "op is not in allowlist"
94+
95+
if target in self.operator_blocklist:
96+
return False, "op is in blocklist"
97+
8098
# Extract the features for the node's operator, if no override was provided
8199
if features is None:
82100
if not has_impl(target):
@@ -93,7 +111,7 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
93111
if op_repsets.any_is_empty():
94112
return (
95113
False,
96-
"No valid representations for a tensor in the operation",
114+
f"no valid representations for op {utils.node_io_str(node)}",
97115
)
98116

99117
return True, "Op is compatible"
@@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner):
277295
def __init__(
278296
self,
279297
compile_options: Optional[Dict[str, Any]] = None,
298+
operator_blocklist: Optional[List[OpKey]] = None,
299+
operator_allowlist: Optional[List[OpKey]] = None,
280300
) -> None:
281301
self.options: Dict[str, Any] = {}
282302
if compile_options is not None:
@@ -285,6 +305,18 @@ def __init__(
285305
compile_spec = parse_compile_options(self.options)
286306
self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
287307

308+
self.operator_blocklist: Set[OpKey] = set()
309+
if operator_blocklist is not None:
310+
for entry in operator_blocklist or []:
311+
self.operator_blocklist.add(entry)
312+
313+
self.operator_allowlist: Optional[Set[OpKey]] = None
314+
if operator_allowlist is not None:
315+
self.operator_allowlist = set()
316+
for entry in operator_allowlist:
317+
assert self.operator_allowlist is not None
318+
self.operator_allowlist.add(entry)
319+
288320
def ops_to_not_decompose(
289321
self, ep: ExportedProgram
290322
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
308340
texture_limits,
309341
buffer_limit,
310342
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
343+
operator_blocklist=self.operator_blocklist,
344+
operator_allowlist=self.operator_allowlist,
311345
),
312346
allows_single_node_partition=True,
313347
)

backends/vulkan/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
format_target_name,
1919
)
2020

21+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
22+
2123
from executorch.exir.tensor import TensorSpec
2224

2325
from torch._export.utils import is_buffer, is_param
@@ -54,6 +56,18 @@
5456
MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]]
5557

5658

59+
def is_torch_op_node(node: torch.fx.Node) -> bool:
60+
if node.op != "call_function":
61+
return False
62+
63+
if isinstance(node.target, EdgeOpOverload):
64+
return True
65+
if isinstance(node.target, torch._ops.OpOverload):
66+
return True
67+
68+
return False
69+
70+
5771
def is_dequant_node(node: torch.fx.Node) -> bool:
5872
if node.op != "call_function":
5973
return False
@@ -1033,6 +1047,49 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
10331047
##
10341048

10351049

1050+
def get_tensor_val_str(tensor_val: FakeTensor) -> str:
1051+
return f"{tensor_val.dtype}: {tensor_val.shape}"
1052+
1053+
1054+
def get_node_val_str(node: torch.fx.Node) -> str:
1055+
if is_single_tensor_node(node):
1056+
assert isinstance(node.meta["val"], FakeTensor)
1057+
return get_tensor_val_str(node.meta["val"])
1058+
elif is_tensor_collection_node(node):
1059+
assert isinstance(node.meta["val"], (list, tuple))
1060+
return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
1061+
else:
1062+
return str(node.meta["val"])
1063+
1064+
1065+
def get_arg_node_val_str(arg_node: Any) -> str:
1066+
if isinstance(arg_node, torch.fx.Node):
1067+
return get_node_val_str(arg_node)
1068+
elif isinstance(arg_node, (list, tuple)):
1069+
return f"[{', '.join(get_arg_node_val_str(n) for n in arg_node)}]"
1070+
else:
1071+
return str(arg_node)
1072+
1073+
1074+
def node_io_str(node: torch.fx.Node) -> str:
1075+
target = node.target
1076+
if isinstance(target, EdgeOpOverload):
1077+
assert isinstance(target, EdgeOpOverload)
1078+
target_name = target.__name__
1079+
elif isinstance(target, torch._ops.OpOverload):
1080+
assert isinstance(target, torch._ops.OpOverload)
1081+
target_name = target.name()
1082+
else:
1083+
target_name = str(target)
1084+
1085+
out_str = f"{get_node_val_str(node)} = {target_name}("
1086+
for arg in node.args:
1087+
out_str += get_arg_node_val_str(arg) + ", "
1088+
1089+
out_str += " ...)"
1090+
return out_str
1091+
1092+
10361093
def update_program_state_dict(
10371094
program: ExportedProgram,
10381095
buffer_name: str,

0 commit comments

Comments (0)