2424 process_placeholder ,
2525)
2626from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
27- from executorch .backends .arm .tosa .mapping import TOSA_TENSOR_NAME_META
2827from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
2928from executorch .exir .backend .compile_spec_schema import CompileSpec
30- from executorch .exir .graph_module import get_control_flow_submodules
3129from torch .export .exported_program import ExportedProgram
32- from torch .fx import Graph , GraphModule , Node
33-
30+ from torch .fx import Graph , Node
3431
3532# TOSA backend debug functionality
3633logger = logging .getLogger (__name__ )
@@ -55,39 +52,13 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
5552 # Walk backwards so we touch every producer
5653 q .extend (n .all_input_nodes )
5754
58- out = ep_graph .output_node ()
59- # First argument of output node is tuple of outputs
60- output_list = cast (tuple , out .args [0 ])
55+ out = next (n for n in ep_graph .nodes if n .op == "output" )
6156 seen : Set [Node ] = set ()
62- for idx , val in enumerate (output_list ):
57+ for idx , val in enumerate (out . args [ 0 ] ):
6358 bfs_mark ([val ], idx , seen )
6459 return node2external_id
6560
6661
67- def _sort_outputs (graph_module : GraphModule , node_to_id_map : dict [str , int ]):
68- def _external_id (n : Node , node_2_id , fallback : int ) -> int :
69- return node_2_id .get (n .name , fallback )
70-
71- out_node = graph_module .graph .output_node ()
72- out_list = cast (tuple , out_node .args [0 ])
73- _counter = count ()
74-
75- # sort nodes by the key that is id
76- def _sort_key (t : Node ) -> int :
77- return _external_id (t , node_to_id_map , next (_counter ))
78-
79- orig_ord = tuple (sorted (out_list , key = _sort_key ))
80-
81- current_order = tuple (out_list )
82- if orig_ord != current_order :
83- replacement = list (orig_ord ) if isinstance (out_node .args [0 ], list ) else orig_ord
84- out_node .args = (replacement ,)
85- graph_module .graph .lint ()
86- graph_module .recompile ()
87-
88- return graph_module
89-
90-
9162def arm_get_first_delegation_tag (graph_module ) -> str :
9263 """Get the first delegation tag from the graph_module or return empty string."""
9364 for node in graph_module .graph .nodes :
@@ -122,9 +93,9 @@ def _preprocess( # noqa: C901
12293 artifact_path = compile_spec .get_intermediate_path ()
12394 tosa_spec = compile_spec .tosa_spec
12495 dump_debug_info = compile_spec .tosa_debug_mode
125- debug_hook = None
126- if dump_debug_info is not None :
127- debug_hook = DebugHook ( dump_debug_info )
96+
97+ # Assign to every node external id
98+ node_2_id = _annotate_external_ids ( edge_program . graph )
12899
129100 logger .info (f"Converting ExportedProgram to TOSA: { tosa_spec } " )
130101
@@ -145,66 +116,45 @@ def _preprocess( # noqa: C901
145116 f"doesn't match specification { tosa_spec } "
146117 )
147118
148- TOSABackend ._preprocess_module (
149- edge_program .graph_module ,
150- edge_program ,
151- compile_spec ,
152- tosa_graph ,
153- debug_hook ,
154- )
155- # Serialize and return the TOSA flatbuffer.
156- binary = tosa_graph .serialize ()
157-
158- if artifact_path :
159- tag = arm_get_first_delegation_tag (edge_program .graph_module )
160- debug_tosa_dump (
161- binary ,
162- artifact_path ,
163- suffix = "{}" .format (f"_{ tag } " if tag else "" ) + (f"_{ tosa_spec } " ),
164- )
165-
166- if debug_hook is not None :
167- if debug_hook .mode == ArmCompileSpec .DebugMode .JSON :
168- json_output = debug_hook .serialize ()
169- with open (f"{ artifact_path } /debug.json" , "w" ) as f :
170- f .write (json_output )
171-
172- return PreprocessResult (processed_bytes = binary )
173-
174- @staticmethod
175- def _preprocess_module ( # noqa: C901
176- graph_module : GraphModule ,
177- edge_program : ExportedProgram ,
178- compile_spec : TosaCompileSpec ,
179- tosa_graph : ts .TosaSerializer ,
180- debug_hook : DebugHook | None ,
181- submodule_name : str | None = None ,
182- ):
183- """Convert 'graph_module' to a tosa_graph"""
184- tosa_spec = compile_spec .tosa_spec
185- node_to_id_map = _annotate_external_ids (graph_module .graph )
186- artifact_path = compile_spec .get_intermediate_path ()
187-
188119 # TODO: Fix the need to lazily import this.
189120 from executorch .backends .arm ._passes import ArmPassManager
190121
191122 graph_module = ArmPassManager (tosa_spec ).transform_to_backend_pipeline ( # type: ignore
192- exported_program = edge_program , graph_module = graph_module
123+ exported_program = edge_program
193124 )
194125
126+ debug_hook = None
127+ if dump_debug_info is not None :
128+ debug_hook = DebugHook (dump_debug_info )
129+
195130 # TODO: Fix the need to lazily import this.
196131 from executorch .backends .arm .operators .node_visitor import get_node_visitors
197132
198133 node_visitors = get_node_visitors (edge_program , tosa_spec , debug_hook )
199- graph_module = _sort_outputs (graph_module , node_to_id_map )
200134
201- if submodule_name is not None :
202- tosa_graph .startRegion (submodule_name )
203- tosa_graph .currRegion .addBasicBlock (submodule_name )
204- suffix = f"_{ submodule_name } "
205- for loop_node in graph_module .graph .nodes :
206- loop_node .meta [TOSA_TENSOR_NAME_META ] = suffix
135+ # Re-shuffle output nodes to preserve author's order
136+ def _external_id (n : Node , node_2_id , fallback : int ) -> int :
137+ return node_2_id .get (n .name , fallback )
138+
139+ out_node = next (n for n in graph_module .graph .nodes if n .op == "output" )
140+ _counter = count ()
141+
142+ # sort nodes by the key that is id
143+ def _sort_key (t : Node ) -> int :
144+ return _external_id (t , node_2_id , next (_counter ))
207145
146+ orig_ord = tuple (sorted (out_node .args [0 ], key = _sort_key ))
147+
148+ current_order = tuple (out_node .args [0 ])
149+ if orig_ord != current_order :
150+ replacement = (
151+ list (orig_ord ) if isinstance (out_node .args [0 ], list ) else orig_ord
152+ )
153+ out_node .args = (replacement ,)
154+ graph_module .graph .lint ()
155+ graph_module .recompile ()
156+
157+ input_count = 0
208158 for node in graph_module .graph .nodes :
209159 node = cast (Node , node )
210160 try :
@@ -214,27 +164,37 @@ def _preprocess_module( # noqa: C901
214164 if len (node .users ) == 0 :
215165 continue
216166 process_placeholder (node , tosa_graph , edge_program , tosa_spec )
167+ if node .name in edge_program .graph_signature .user_inputs :
168+ input_count += 1
217169 elif node .op == "output" :
218- process_output (node , tosa_graph , tosa_spec )
170+ process_output (node , tosa_graph )
219171 else :
220172 # This will only happen if an unpartitioned graph is passed without
221173 # any checking of compatibility.
222174 raise RuntimeError (f"{ node .name } is unsupported op { node .op } " )
223175 except Exception :
224- debug_fail (node , graph_module , tosa_graph , artifact_path )
176+ debug_fail (node , graph_module , tosa_graph . serialize () , artifact_path )
225177 raise
226178
227- # Recursively preprocess controlflow submodules .
228- for name , submodule , _ in get_control_flow_submodules ( graph_module ):
229- TOSABackend . _preprocess_module (
230- submodule ,
231- edge_program ,
232- compile_spec ,
233- tosa_graph ,
234- debug_hook ,
235- submodule_name = name ,
179+ # Serialize and return the TOSA flatbuffer .
180+ binary = tosa_graph . serialize ()
181+
182+ if artifact_path :
183+ tag = arm_get_first_delegation_tag ( graph_module )
184+ debug_tosa_dump (
185+ binary ,
186+ artifact_path ,
187+ suffix = "{}" . format ( f"_ { tag } " if tag else "" ) + ( f"_ { tosa_spec } " ) ,
236188 )
237189
190+ if debug_hook is not None :
191+ if debug_hook .mode == ArmCompileSpec .DebugMode .JSON :
192+ json_output = debug_hook .serialize ()
193+ with open (f"{ artifact_path } /debug.json" , "w" ) as f :
194+ f .write (json_output )
195+
196+ return PreprocessResult (processed_bytes = binary )
197+
238198 @staticmethod
239199 def filter_tosa_compile_specs (
240200 compile_spec : ArmCompileSpec ,
0 commit comments