66
77# pyre-unsafe
88
9- import operator
10- from typing import Callable , cast , Dict , final , List , Optional , Set , Tuple
9+ from typing import Callable , Dict , final , List , Optional , Tuple
1110
1211import torch
1312from executorch .backends .aoti .aoti_backend import AotiBackend # usort: skip
1817 PartitionResult ,
1918)
2019from executorch .exir .backend .utils import tag_constant_data
21- from executorch .exir .dialects ._ops import ops as exir_ops
2220from torch .export .exported_program import ExportedProgram
23- from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
24-
25- from torch .fx .passes .operator_support import OperatorSupportBase
26-
27-
28- class AOTISupportedOperators (OperatorSupportBase ):
29- def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
30- # supported = node.op == "call_function" and (
31- # node.target == operator.getitem
32- # or str(node.target._op) not in inductor_fallback_ops
33- # or str(node.target._op) in supported_fallback_operators
34- # )
35-
36- supported = node .op == "call_function"
37-
38- return supported
39-
40- def is_node_supported_custom (self , node : torch .fx .Node ) -> bool :
41- if node .target == exir_ops .edge .aten .mean .dim :
42- keep_dim = node .args [2 ] if len (node .args ) > 2 else False
43- return cast (bool , keep_dim )
44- if node .target == exir_ops .edge .aten .var .correction :
45- keep_dim = node .kwargs .get ("keepdim" , False )
46- return cast (bool , keep_dim )
47- return True
4821
4922
5023@final
5124class AotiPartitioner (Partitioner ):
5225 def __init__ (self , compile_spec : List [CompileSpec ]) -> None :
5326 self .delegation_spec = DelegationSpec (AotiBackend .__name__ , compile_spec )
54- print (self .delegation_spec )
5527
5628 def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
57- # Run the CapabilityBasedPartitioner to return the largest possible
58- # subgraphs containing the nodes with the tags
59- # logger.info("AotiPartitioner::partition")
60- print ("entering partitioner..." )
61-
62- partition_tags = {}
63-
64- capability_partitioner = CapabilityBasedPartitioner (
65- exported_program .graph_module ,
66- AOTISupportedOperators (),
67- allows_single_node_partition = True ,
68- )
69- partition_list = capability_partitioner .propose_partitions ()
70-
71- assert len (partition_list ) == 1 , "Graph break is not supported yet"
72-
73- print (f"graph breaks into { len (partition_list )} parts" )
29+ """
30+ Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
31+ """
7432
75- for partition in partition_list :
76- for node in partition .nodes :
77- tag = f"tag{ partition .id } "
78- node .meta ["delegation_tag" ] = tag
79- partition_tags [tag ] = self .delegation_spec
33+ partition_tags : Dict [str , DelegationSpec ] = {}
34+ for node in exported_program .graph .nodes :
35+ if node .op != "call_function" :
36+ continue
37+ tag = f"tag0"
38+ node .meta ["delegation_tag" ] = tag
39+ partition_tags [tag ] = self .delegation_spec
8040
8141 tag_constant_data (exported_program )
8242
@@ -89,15 +49,13 @@ def ops_to_not_decompose(
8949 ) -> Tuple [List [torch ._ops .OpOverload ], Optional [Callable [[torch .fx .Node ], bool ]]]:
9050 """
9151 Return a list of operations that should not be decomposed and let the AOT compiler handle them.
52+ Currently we skip decomposing all ops and let the AOT compiler handle them.
9253 """
9354 do_not_decompose = set ()
94- op_support = AOTISupportedOperators ()
9555
9656 for node in ep .graph .nodes :
97- if (
98- node .op == "call_function"
99- and isinstance (node .target , torch ._ops .OpOverload )
100- and op_support .is_node_supported (None , node )
57+ if node .op == "call_function" and isinstance (
58+ node .target , torch ._ops .OpOverload
10159 ):
10260 do_not_decompose .add (node .target )
10361 return list (do_not_decompose ), None
0 commit comments