77# pyre-strict
88
99import logging
10- from typing import Any , Callable , Dict , final , List , Mapping , Optional , Tuple
10+ from typing import Any , Callable , Dict , final , List , Mapping , Optional , Set , Tuple
1111
1212import executorch .backends .vulkan .utils as utils
1313
1717 get_op_features ,
1818 has_impl ,
1919 OpFeatures ,
20+ OpKey ,
2021 vulkan_supported_ops ,
2122)
2223
@@ -55,11 +56,17 @@ def __init__(
5556 texture_limits : utils .ImageExtents ,
5657 buffer_limit : int ,
5758 require_dynamic_shape : bool = False ,
59+ operator_blocklist : Optional [Set [OpKey ]] = None ,
60+ operator_allowlist : Optional [Set [OpKey ]] = None ,
5861 ) -> None :
5962 super ().__init__ ()
6063 self .texture_limits : utils .ImageExtents = texture_limits
6164 self .buffer_limit = buffer_limit
6265 self .require_dynamic_shapes = require_dynamic_shape
66+ self .operator_blocklist : Set [OpKey ] = (
67+ operator_blocklist if operator_blocklist is not None else set ()
68+ )
69+ self .operator_allowlist = operator_allowlist
6370
6471 def op_node_is_compatible ( # noqa: C901: Function is too complex
6572 self , node : torch .fx .Node , features : Optional [OpFeatures ] = None
@@ -77,6 +84,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
7784 assert isinstance (first_arg , torch ._ops .OpOverload )
7885 target = first_arg .name ()
7986
87+ # Operator allow list is only used for torch ops
88+ if (
89+ utils .is_torch_op_node (node )
90+ and (self .operator_allowlist is not None )
91+ and (target not in self .operator_allowlist )
92+ ):
93+ return False , "op is not in allowlist"
94+
95+ if target in self .operator_blocklist :
96+ return False , "op is in blocklist"
97+
8098 # Extract the features for the node's operator, if no override was provided
8199 if features is None :
82100 if not has_impl (target ):
@@ -93,7 +111,7 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
93111 if op_repsets .any_is_empty ():
94112 return (
95113 False ,
96- "No valid representations for a tensor in the operation " ,
114+ f"no valid representations for op { utils . node_io_str ( node ) } " ,
97115 )
98116
99117 return True , "Op is compatible"
@@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner):
277295 def __init__ (
278296 self ,
279297 compile_options : Optional [Dict [str , Any ]] = None ,
298+ operator_blocklist : Optional [List [OpKey ]] = None ,
299+ operator_allowlist : Optional [List [OpKey ]] = None ,
280300 ) -> None :
281301 self .options : Dict [str , Any ] = {}
282302 if compile_options is not None :
@@ -285,6 +305,18 @@ def __init__(
285305 compile_spec = parse_compile_options (self .options )
286306 self .delegation_spec = DelegationSpec (VulkanBackend .__name__ , compile_spec )
287307
308+ self .operator_blocklist : Set [OpKey ] = set ()
309+ if operator_blocklist is not None :
310+ for entry in operator_blocklist or []:
311+ self .operator_blocklist .add (entry )
312+
313+ self .operator_allowlist : Optional [Set [OpKey ]] = None
314+ if operator_allowlist is not None :
315+ self .operator_allowlist = set ()
316+ for entry in operator_allowlist :
317+ assert self .operator_allowlist is not None
318+ self .operator_allowlist .add (entry )
319+
288320 def ops_to_not_decompose (
289321 self , ep : ExportedProgram
290322 ) -> Tuple [List [torch ._ops .OpOverload ], Optional [Callable [[torch .fx .Node ], bool ]]]:
@@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
308340 texture_limits ,
309341 buffer_limit ,
310342 require_dynamic_shape = self .options .get ("require_dynamic_shapes" , False ),
343+ operator_blocklist = self .operator_blocklist ,
344+ operator_allowlist = self .operator_allowlist ,
311345 ),
312346 allows_single_node_partition = True ,
313347 )
0 commit comments