2727from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
2828from executorch .exir .backend .compile_spec_schema import CompileSpec
2929from torch .export .exported_program import ExportedProgram
30- from torch .fx import Graph , Node
30+ from torch .fx import Graph , GraphModule , Node
3131
3232# TOSA backend debug functionality
3333logger = logging .getLogger (__name__ )
@@ -52,13 +52,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
5252 # Walk backwards so we touch every producer
5353 q .extend (n .all_input_nodes )
5454
55- out = next (n for n in ep_graph .nodes if n .op == "output" )
55+ out = ep_graph .output_node ()
56+ # First argument of output node is tuple of outputs
57+ output_list = cast (tuple , out .args [0 ])
5658 seen : Set [Node ] = set ()
57- for idx , val in enumerate (out . args [ 0 ] ):
59+ for idx , val in enumerate (output_list ):
5860 bfs_mark ([val ], idx , seen )
5961 return node2external_id
6062
6163
64+ def _sort_outputs (graph_module : GraphModule , node_to_id_map : dict [str , int ]):
65+ def _external_id (n : Node , node_2_id , fallback : int ) -> int :
66+ return node_2_id .get (n .name , fallback )
67+
68+ out_node = graph_module .graph .output_node ()
69+ out_list = cast (tuple , out_node .args [0 ])
70+ _counter = count ()
71+
72+ # sort nodes by the key that is id
73+ def _sort_key (t : Node ) -> int :
74+ return _external_id (t , node_to_id_map , next (_counter ))
75+
76+ orig_ord = tuple (sorted (out_list , key = _sort_key ))
77+
78+ current_order = tuple (out_list )
79+ if orig_ord != current_order :
80+ replacement = list (orig_ord ) if isinstance (out_node .args [0 ], list ) else orig_ord
81+ out_node .args = (replacement ,)
82+ graph_module .graph .lint ()
83+ graph_module .recompile ()
84+
85+ return graph_module
86+
87+
6288def arm_get_first_delegation_tag (graph_module ) -> str :
6389 """Get the first delegation tag from the graph_module or return empty string."""
6490 for node in graph_module .graph .nodes :
@@ -93,9 +119,9 @@ def _preprocess( # noqa: C901
93119 artifact_path = compile_spec .get_intermediate_path ()
94120 tosa_spec = compile_spec .tosa_spec
95121 dump_debug_info = compile_spec .tosa_debug_mode
96-
97- # Assign to every node external id
98- node_2_id = _annotate_external_ids ( edge_program . graph )
122+ debug_hook = None
123+ if dump_debug_info is not None :
124+ debug_hook = DebugHook ( dump_debug_info )
99125
100126 logger .info (f"Converting ExportedProgram to TOSA: { tosa_spec } " )
101127
@@ -116,43 +142,57 @@ def _preprocess( # noqa: C901
116142 f"doesn't match specification { tosa_spec } "
117143 )
118144
145+ TOSABackend ._preprocess_module (
146+ edge_program .graph_module ,
147+ edge_program ,
148+ compile_spec ,
149+ tosa_graph ,
150+ debug_hook ,
151+ )
152+ # Serialize and return the TOSA flatbuffer.
153+ binary = tosa_graph .serialize ()
154+
155+ if artifact_path :
156+ tag = arm_get_first_delegation_tag (edge_program .graph_module )
157+ debug_tosa_dump (
158+ binary ,
159+ artifact_path ,
160+ suffix = "{}" .format (f"_{ tag } " if tag else "" ) + (f"_{ tosa_spec } " ),
161+ )
162+
163+ if debug_hook is not None :
164+ if debug_hook .mode == ArmCompileSpec .DebugMode .JSON :
165+ json_output = debug_hook .serialize ()
166+ with open (f"{ artifact_path } /debug.json" , "w" ) as f :
167+ f .write (json_output )
168+
169+ return PreprocessResult (processed_bytes = binary )
170+
171+ @staticmethod
172+ def _preprocess_module (
173+ graph_module : GraphModule ,
174+ edge_program : ExportedProgram ,
175+ compile_spec : TosaCompileSpec ,
176+ tosa_graph : ts .TosaSerializer ,
177+ debug_hook : DebugHook | None ,
178+ ):
179+ """Convert 'graph_module' to a tosa_graph"""
180+ tosa_spec = compile_spec .tosa_spec
181+ node_to_id_map = _annotate_external_ids (graph_module .graph )
182+ artifact_path = compile_spec .get_intermediate_path ()
183+
119184 # TODO: Fix the need to lazily import this.
120185 from executorch .backends .arm ._passes import ArmPassManager
121186
122187 graph_module = ArmPassManager (tosa_spec ).transform_to_backend_pipeline ( # type: ignore
123- exported_program = edge_program
188+ exported_program = edge_program , graph_module = graph_module
124189 )
125190
126- debug_hook = None
127- if dump_debug_info is not None :
128- debug_hook = DebugHook (dump_debug_info )
129-
130191 # TODO: Fix the need to lazily import this.
131192 from executorch .backends .arm .operators .node_visitor import get_node_visitors
132193
133194 node_visitors = get_node_visitors (edge_program , tosa_spec , debug_hook )
134-
135- # Re-shuffle output nodes to preserve author's order
136- def _external_id (n : Node , node_2_id , fallback : int ) -> int :
137- return node_2_id .get (n .name , fallback )
138-
139- out_node = next (n for n in graph_module .graph .nodes if n .op == "output" )
140- _counter = count ()
141-
142- # sort nodes by the key that is id
143- def _sort_key (t : Node ) -> int :
144- return _external_id (t , node_2_id , next (_counter ))
145-
146- orig_ord = tuple (sorted (out_node .args [0 ], key = _sort_key ))
147-
148- current_order = tuple (out_node .args [0 ])
149- if orig_ord != current_order :
150- replacement = (
151- list (orig_ord ) if isinstance (out_node .args [0 ], list ) else orig_ord
152- )
153- out_node .args = (replacement ,)
154- graph_module .graph .lint ()
155- graph_module .recompile ()
195+ graph_module = _sort_outputs (graph_module , node_to_id_map )
156196
157197 input_count = 0
158198 for node in graph_module .graph .nodes :
@@ -176,25 +216,6 @@ def _sort_key(t: Node) -> int:
176216 debug_fail (node , graph_module , tosa_graph .serialize (), artifact_path )
177217 raise
178218
179- # Serialize and return the TOSA flatbuffer.
180- binary = tosa_graph .serialize ()
181-
182- if artifact_path :
183- tag = arm_get_first_delegation_tag (graph_module )
184- debug_tosa_dump (
185- binary ,
186- artifact_path ,
187- suffix = "{}" .format (f"_{ tag } " if tag else "" ) + (f"_{ tosa_spec } " ),
188- )
189-
190- if debug_hook is not None :
191- if debug_hook .mode == ArmCompileSpec .DebugMode .JSON :
192- json_output = debug_hook .serialize ()
193- with open (f"{ artifact_path } /debug.json" , "w" ) as f :
194- f .write (json_output )
195-
196- return PreprocessResult (processed_bytes = binary )
197-
198219 @staticmethod
199220 def filter_tosa_compile_specs (
200221 compile_spec : ArmCompileSpec ,
0 commit comments