|
20 | 20 | from executorch.backends.arm.operators.node_visitor import get_node_visitors |
21 | 21 | from executorch.backends.arm.operators.op_output import process_output |
22 | 22 | from executorch.backends.arm.operators.op_placeholder import process_placeholder |
| 23 | + |
| 24 | +from executorch.backends.arm.tosa_specification import TosaSpecification |
23 | 25 | from executorch.backends.arm._passes.arm_pass_manager import ( |
24 | 26 | ArmPassManager, |
25 | 27 | ) # usort: skip |
@@ -87,16 +89,23 @@ def ethosu_compile_spec( |
87 | 89 | if extra_flags is not None: |
88 | 90 | self.compiler_flags.append(extra_flags) |
89 | 91 |
|
| 92 | + base_tosa_version = "TOSA-0.80.0+BI" |
| 93 | + if "U55" in config: |
| 94 | + # Add the Ethos-U55 extension marker |
| 95 | + base_tosa_version += "+u55" |
| 96 | + self.tosa_version = TosaSpecification.create_from_string(base_tosa_version) |
| 97 | + |
90 | 98 | return self |
91 | 99 |
|
92 | | - def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": |
| 100 | + def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder": |
93 | 101 | """ |
94 | 102 | Generate compile spec for TOSA flatbuffer output |
95 | 103 | """ |
96 | 104 | assert ( |
97 | 105 | self.output_format is None |
98 | 106 | ), f"Output format already set: {self.output_format}" |
99 | 107 | self.output_format = "tosa" |
| 108 | + self.tosa_version = TosaSpecification.create_from_string(tosa_version) |
100 | 109 | return self |
101 | 110 |
|
102 | 111 | def dump_intermediate_artifacts_to( |
@@ -130,6 +139,13 @@ def build(self) -> List[CompileSpec]: |
130 | 139 | """ |
131 | 140 | Generate a list of compile spec objects from the builder |
132 | 141 | """ |
| 142 | + assert self.tosa_version |
| 143 | + |
| 144 | + # Always supply a TOSA version |
| 145 | + self.compile_spec = [ |
| 146 | + CompileSpec("tosa_version", str(self.tosa_version).encode()) |
| 147 | + ] |
| 148 | + |
133 | 149 | if self.output_format == "vela": |
134 | 150 | self.compile_spec += [ |
135 | 151 | CompileSpec("output_format", "vela".encode()), |
@@ -211,26 +227,33 @@ def preprocess( # noqa: C901 |
211 | 227 | if not output_format: |
212 | 228 | raise RuntimeError("output format is required") |
213 | 229 |
|
| 230 | + tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec) |
| 231 | + assert ( |
| 232 | + tosa_spec is not None |
| 233 | + ), "TOSA backend needs a TOSA version specified in the CompileSpec!" |
| 234 | + |
214 | 235 | if output_format == "vela" and len(compile_flags) == 0: |
215 | 236 | # Not testing for compile_flags correctness here, just that they are |
216 | 237 | # present. The compiler will give errors if they are not valid. |
217 | 238 | raise RuntimeError("compile flags are required for vela output format") |
218 | 239 |
|
| 240 | + logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") |
| 241 | + |
219 | 242 | # Converted output for this subgraph, serializer needs path early as it emits |
220 | 243 | # const data directly. Path created and data written only in debug builds. |
221 | 244 | tosa_graph = ts.TosaSerializer(artifact_path) |
222 | 245 | graph_module = ArmPassManager().transform_to_backend_pipeline( |
223 | 246 | exported_program=edge_program, compile_spec=compile_spec |
224 | 247 | ) |
225 | 248 |
|
226 | | - node_visitors = get_node_visitors(edge_program) |
| 249 | + node_visitors = get_node_visitors(edge_program, tosa_spec) |
227 | 250 |
|
228 | 251 | for node in graph_module.graph.nodes: |
229 | 252 | node = cast(Node, node) |
230 | 253 | if node.op == "call_function": |
231 | | - process_call_function(node, tosa_graph, node_visitors) |
| 254 | + process_call_function(node, tosa_graph, node_visitors, tosa_spec) |
232 | 255 | elif node.op == "placeholder": |
233 | | - process_placeholder(node, tosa_graph, edge_program) |
| 256 | + process_placeholder(node, tosa_graph, edge_program, tosa_spec) |
234 | 257 | elif node.op == "output": |
235 | 258 | process_output(node, tosa_graph) |
236 | 259 | else: |
|
0 commit comments