
Commit eadf7ea

Update

[ghstack-poisoned]

1 parent: 97a3aac

2 files changed (+93, -2 lines)

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 36 additions & 2 deletions
@@ -7,7 +7,7 @@
 # pyre-strict
 
 import logging
-from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple
+from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple
 
 import executorch.backends.vulkan.utils as utils
 
@@ -17,6 +17,7 @@
     get_op_features,
     has_impl,
     OpFeatures,
+    OpKey,
     vulkan_supported_ops,
 )
 
@@ -55,11 +56,17 @@ def __init__(
         texture_limits: utils.ImageExtents,
         buffer_limit: int,
         require_dynamic_shape: bool = False,
+        operator_blocklist: Optional[Set[OpKey]] = None,
+        operator_allowlist: Optional[Set[OpKey]] = None,
     ) -> None:
         super().__init__()
         self.texture_limits: utils.ImageExtents = texture_limits
         self.buffer_limit = buffer_limit
         self.require_dynamic_shapes = require_dynamic_shape
+        self.operator_blocklist: Set[OpKey] = (
+            operator_blocklist if operator_blocklist is not None else set()
+        )
+        self.operator_allowlist = operator_allowlist
 
     def op_node_is_compatible( # noqa: C901: Function is too complex
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -77,6 +84,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
         assert isinstance(first_arg, torch._ops.OpOverload)
         target = first_arg.name()
 
+        # Operator allow list is only used for torch ops
+        if (
+            utils.is_torch_op_node(node)
+            and (self.operator_allowlist is not None)
+            and (target not in self.operator_allowlist)
+        ):
+            return False, "op is not in allowlist"
+
+        if target in self.operator_blocklist:
+            return False, "op is in blocklist"
+
         # Extract the features for the node's operator, if no override was provided
         if features is None:
             if not has_impl(target):
@@ -93,7 +111,7 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
         if op_repsets.any_is_empty():
             return (
                 False,
-                "No valid representations for a tensor in the operation",
+                f"no valid representations for op {utils.node_io_str(node)}",
             )
 
         return True, "Op is compatible"
@@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner):
     def __init__(
        self,
        compile_options: Optional[Dict[str, Any]] = None,
+        operator_blocklist: Optional[List[OpKey]] = None,
+        operator_allowlist: Optional[List[OpKey]] = None,
    ) -> None:
        self.options: Dict[str, Any] = {}
        if compile_options is not None:
@@ -285,6 +305,18 @@ def __init__(
         compile_spec = parse_compile_options(self.options)
         self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
 
+        self.operator_blocklist: Set[OpKey] = set()
+        if operator_blocklist is not None:
+            for entry in operator_blocklist or []:
+                self.operator_blocklist.add(entry)
+
+        self.operator_allowlist: Optional[Set[OpKey]] = None
+        if operator_allowlist is not None:
+            self.operator_allowlist = set()
+            for entry in operator_allowlist:
+                assert self.operator_allowlist is not None
+                self.operator_allowlist.add(entry)
+
     def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 texture_limits,
                 buffer_limit,
                 require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
+                operator_blocklist=self.operator_blocklist,
+                operator_allowlist=self.operator_allowlist,
             ),
            allows_single_node_partition=True,
        )
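
A minimal usage sketch of the new constructor parameters, assuming the standard torch.export plus to_edge_transform_and_lower flow and assuming that OpKey entries may be operator name strings of the form returned by torch._ops.OpOverload.name() (e.g. "aten::mm.default"), which is what the `target in self.operator_blocklist` check above compares against. The model and op names are illustrative only, not part of this change:

# Hypothetical sketch: delegate everything except matmul to Vulkan.
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower


class SmallModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return torch.relu(x @ w)


example_inputs = (torch.randn(4, 8), torch.randn(8, 16))

# Keep matmul out of the Vulkan delegate; the entry format is an assumption here.
partitioner = VulkanPartitioner(
    compile_options=None,
    operator_blocklist=["aten::mm.default"],
)

exported = torch.export.export(SmallModel(), example_inputs)
edge_program = to_edge_transform_and_lower(exported, partitioner=[partitioner])

Passing operator_allowlist instead would restrict delegation to only the listed torch ops, per the allowlist check in op_node_is_compatible.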

backends/vulkan/utils.py

Lines changed: 57 additions & 0 deletions
@@ -18,6 +18,8 @@
     format_target_name,
 )
 
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+
 from executorch.exir.tensor import TensorSpec
 
 from torch._export.utils import is_buffer, is_param
@@ -54,6 +56,18 @@
 MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]]
 
 
+def is_torch_op_node(node: torch.fx.Node) -> bool:
+    if node.op != "call_function":
+        return False
+
+    if isinstance(node.target, EdgeOpOverload):
+        return True
+    if isinstance(node.target, torch._ops.OpOverload):
+        return True
+
+    return False
+
+
 def is_dequant_node(node: torch.fx.Node) -> bool:
     if node.op != "call_function":
         return False
@@ -1033,6 +1047,49 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
 ##
 
 
+def get_tensor_val_str(tensor_val: FakeTensor) -> str:
+    return f"{tensor_val.dtype}: {tensor_val.shape}"
+
+
+def get_node_val_str(node: torch.fx.Node) -> str:
+    if is_single_tensor_node(node):
+        assert isinstance(node.meta["val"], FakeTensor)
+        return get_tensor_val_str(node.meta["val"])
+    elif is_tensor_collection_node(node):
+        assert isinstance(node.meta["val"], (list, tuple))
+        return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
+    else:
+        return str(node.meta["val"])
+
+
+def get_arg_node_val_str(arg_node: Any) -> str:
+    if isinstance(arg_node, torch.fx.Node):
+        return get_node_val_str(arg_node)
+    elif isinstance(arg_node, (list, tuple)):
+        return f"[{', '.join(get_arg_node_val_str(n) for n in arg_node)}]"
+    else:
+        return str(arg_node)
+
+
+def node_io_str(node: torch.fx.Node) -> str:
+    target = node.target
+    if isinstance(target, EdgeOpOverload):
+        assert isinstance(target, EdgeOpOverload)
+        target_name = target.__name__
+    elif isinstance(target, torch._ops.OpOverload):
+        assert isinstance(target, torch._ops.OpOverload)
+        target_name = target.name()
+    else:
+        target_name = str(target)
+
+    out_str = f"{get_node_val_str(node)} = {target_name}("
+    for arg in node.args:
+        out_str += get_arg_node_val_str(arg) + ", "
+
+    out_str += " ...)"
+    return out_str
+
+
 def update_program_state_dict(
     program: ExportedProgram,
     buffer_name: str,
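
A hypothetical debugging sketch for the new helpers: walk an exported graph and print a one-line summary of each torch op node, which is the same string the partitioner now embeds in its "no valid representations" rejection message. It assumes the helpers are importable from executorch.backends.vulkan.utils as added above and that node.meta["val"] is populated (as it is for torch.export programs); the model is illustrative only:

# Hypothetical sketch using is_torch_op_node / node_io_str for graph inspection.
import torch
from executorch.backends.vulkan.utils import is_torch_op_node, node_io_str


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(x + 1.0)


ep = torch.export.export(TinyModel(), (torch.randn(2, 3),))

for node in ep.graph_module.graph.nodes:
    # Only ATen / Edge dialect call_function nodes qualify as torch op nodes.
    if is_torch_op_node(node):
        # Prints a line roughly like:
        # "torch.float32: torch.Size([2, 3]) = aten::add.Tensor(torch.float32: torch.Size([2, 3]), 1.0,  ...)"
        print(node_io_str(node))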
