@@ -61,6 +61,8 @@ def __init__(
61
61
operator_blocklist : Optional [Set [OpKey ]] = None ,
62
62
operator_allowlist : Optional [Set [OpKey ]] = None ,
63
63
fusable_subgraphs : Optional [List [InternalMatch ]] = None ,
64
+ nn_module_blocklist : Optional [Set [str ]] = None ,
65
+ nn_module_allowlist : Optional [Set [str ]] = None ,
64
66
) -> None :
65
67
super ().__init__ ()
66
68
self .texture_limits : utils .ImageExtents = texture_limits
@@ -78,6 +80,9 @@ def __init__(
78
80
for match in self .fusable_subgraphs :
79
81
self .fusable_nodes .update (match .nodes_map .values ())
80
82
83
+ self .nn_module_blocklist = nn_module_blocklist
84
+ self .nn_module_allowlist = nn_module_allowlist
85
+
81
86
def op_node_is_compatible ( # noqa: C901: Function is too complex
82
87
self , node : torch .fx .Node , features : Optional [OpFeatures ] = None
83
88
) -> Tuple [bool , str ]:
@@ -213,10 +218,26 @@ def is_node_supported(
213
218
r = self ._is_node_supported (node )
214
219
return r
215
220
216
- def _is_node_supported (self , node : torch .fx .Node ) -> bool :
217
- # Check if this node is part of a fusable subgraph
218
- if node .op == "call_function" and node in self .fusable_nodes :
219
- return True
221
+ def _is_node_supported (self , node : torch .fx .Node ) -> bool : # noqa: C901
222
+ if node .op == "call_function" :
223
+ # Apply nn module allowlist and blocklist
224
+ if self .nn_module_allowlist is not None :
225
+ if not utils .node_comes_from_any_nn_module_in_set (
226
+ node , self .nn_module_allowlist
227
+ ):
228
+ self .log_skip (node , "source nn.Module is not in allowlist" )
229
+ return False
230
+
231
+ if self .nn_module_blocklist is not None :
232
+ if utils .node_comes_from_any_nn_module_in_set (
233
+ node , self .nn_module_blocklist
234
+ ):
235
+ self .log_skip (node , "source nn.Module is in blocklist" )
236
+ return False
237
+
238
+ # Check if this node is part of a fusable subgraph
239
+ if node in self .fusable_nodes :
240
+ return True
220
241
221
242
target = node .target
222
243
if node .target == torch .ops .higher_order .auto_functionalized :
@@ -311,6 +332,8 @@ def __init__(
311
332
compile_options : Optional [Dict [str , Any ]] = None ,
312
333
operator_blocklist : Optional [List [OpKey ]] = None ,
313
334
operator_allowlist : Optional [List [OpKey ]] = None ,
335
+ nn_module_blocklist : Optional [List [str ]] = None ,
336
+ nn_module_allowlist : Optional [List [str ]] = None ,
314
337
) -> None :
315
338
self .options : Dict [str , Any ] = {}
316
339
if compile_options is not None :
@@ -331,6 +354,20 @@ def __init__(
331
354
assert self .operator_allowlist is not None
332
355
self .operator_allowlist .add (entry )
333
356
357
+ self .nn_module_blocklist : Optional [Set [str ]] = None
358
+ if nn_module_blocklist is not None :
359
+ self .nn_module_blocklist = set ()
360
+ for entry in nn_module_blocklist or []:
361
+ assert self .nn_module_blocklist is not None
362
+ self .nn_module_blocklist .add (entry )
363
+
364
+ self .nn_module_allowlist : Optional [Set [str ]] = None
365
+ if nn_module_allowlist is not None :
366
+ self .nn_module_allowlist = set ()
367
+ for entry in nn_module_allowlist :
368
+ assert self .nn_module_allowlist is not None
369
+ self .nn_module_allowlist .add (entry )
370
+
334
371
def ops_to_not_decompose (
335
372
self , ep : ExportedProgram
336
373
) -> Tuple [List [torch ._ops .OpOverload ], Optional [Callable [[torch .fx .Node ], bool ]]]:
@@ -362,6 +399,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
362
399
operator_blocklist = self .operator_blocklist ,
363
400
operator_allowlist = self .operator_allowlist ,
364
401
fusable_subgraphs = fusable_subgraphs ,
402
+ nn_module_blocklist = self .nn_module_blocklist ,
403
+ nn_module_allowlist = self .nn_module_allowlist ,
365
404
),
366
405
allows_single_node_partition = True ,
367
406
)
0 commit comments