@@ -52,6 +52,7 @@ def __init__(self):
5252 self .permute_nhwc = False
5353 self .quantize_io = False
5454 self .tosa_version = None
55+ self .input_order = None
5556
5657 def ethosu_compile_spec (
5758 self ,
@@ -89,7 +90,7 @@ def ethosu_compile_spec(
8990 self .compiler_flags .append (extra_flags )
9091
9192 base_tosa_version = "TOSA-0.80.0+BI"
92- if "U55 " in config :
93+ if "u55 " in config :
9394 # Add the Ethos-U55 extension marker
9495 base_tosa_version += "+u55"
9596 self .tosa_version = TosaSpecification .create_from_string (base_tosa_version )
@@ -134,6 +135,14 @@ def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
134135 self .quantize_io = quantize_io
135136 return self
136137
138+ def set_input_order (self , input_order : str = None ) -> "ArmCompileSpecBuilder" :
139+ """
140+ Store the requested input order on the builder and return the builder for chaining.
141+ Reordering may be required when there is more than one input and the U55/U85 CompileSpec is used.
142+ """
143+ self .input_order = input_order
144+ return self
145+
137146 def build (self ) -> List [CompileSpec ]:
138147 """
139148 Generate a list of compile spec objects from the builder
@@ -163,6 +172,13 @@ def build(self) -> List[CompileSpec]:
163172 CompileSpec ("permute_memory_format" , "nhwc" .encode ())
164173 )
165174
175+ if self .input_order :
176+ self .compile_spec .append (
177+ CompileSpec (
178+ "input_order" , " " .join (map (str , self .input_order )).encode ()
179+ )
180+ )
181+
166182 if self .quantize_io :
167183 self .compile_spec .append (CompileSpec ("quantize_io" , "True" .encode ()))
168184
@@ -214,13 +230,16 @@ def preprocess( # noqa: C901
214230 artifact_path = None
215231 output_format = ""
216232 compile_flags = []
233+ input_order = []
217234 for spec in compile_spec :
218235 if spec .key == "debug_artifact_path" :
219236 artifact_path = spec .value .decode ()
220237 if spec .key == "output_format" :
221238 output_format = spec .value .decode ()
222239 if spec .key == "compile_flags" :
223240 compile_flags .append (spec .value .decode ())
241+ if spec .key == "input_order" :
242+ input_order = list (map (int , spec .value .decode ().split ("," )))
224243
225244 # Check that the output format is set in the compile spec
226245 if not output_format :
@@ -246,19 +265,27 @@ def preprocess( # noqa: C901
246265 )
247266
248267 node_visitors = get_node_visitors (edge_program , tosa_spec )
249-
268+ input_count = 0
250269 for node in graph_module .graph .nodes :
251270 if node .op == "call_function" :
252271 process_call_function (node , tosa_graph , node_visitors , tosa_spec )
253272 elif node .op == "placeholder" :
254273 process_placeholder (node , tosa_graph , edge_program , tosa_spec )
274+ if node .name in edge_program .graph_signature .user_inputs :
275+ input_count += 1
255276 elif node .op == "output" :
256277 process_output (node , tosa_graph )
257278 else :
258279 # This will only happen if an unpartitioned graph is passed without
259280 # any checking of compatibility.
260281 dbg_fail (node , tosa_graph , artifact_path )
261282
283+ if len (input_order ) > 0 :
284+ if input_count != len (input_order ):
285+ raise RuntimeError (
286+ "The rank of the input order is not equal to amount of input tensors"
287+ )
288+
262289 # TODO: It would be awesome if this dump could somehow be done on top level and not here.
263290 # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
264291 # access from top level.
@@ -275,7 +302,7 @@ def preprocess( # noqa: C901
275302 # preprocess and some consume TOSA fb directly.
276303 if output_format == "vela" :
277304 # Emit vela_bin_stream format
278- binary = vela_compile (tosa_graph , compile_flags )
305+ binary = vela_compile (tosa_graph , compile_flags , input_order )
279306 elif output_format == "tosa" :
280307 # Emit TOSA flatbuffer
281308 binary = bytes (tosa_graph .serialize ())
0 commit comments