24
24
process_output ,
25
25
process_placeholder ,
26
26
)
27
- from executorch .backends .arm .tosa .specification import get_tosa_spec
27
+ from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
28
28
from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
29
29
from executorch .exir .backend .compile_spec_schema import CompileSpec
30
30
from torch .export .exported_program import ExportedProgram
@@ -80,38 +80,24 @@ class TOSABackend(BackendDetails):
80
80
"""
81
81
82
82
@staticmethod
83
- def preprocess ( # noqa: C901
83
+ def preprocess (edge_program : ExportedProgram , compile_specs : List [CompileSpec ]):
84
+ return TOSABackend ._preprocess (
85
+ edge_program , TosaCompileSpec .from_list (compile_specs )
86
+ )
87
+
88
+ @staticmethod
89
+ def _preprocess ( # noqa: C901
84
90
edge_program : ExportedProgram ,
85
- compile_spec : List [ CompileSpec ] ,
91
+ compile_spec : TosaCompileSpec ,
86
92
) -> PreprocessResult :
87
93
# if a debug/test build capture output files from TOSA stage
88
- artifact_path = None
89
- output_format = ""
90
- compile_flags = []
91
- dump_debug_info = None
92
- for spec in compile_spec :
93
- if spec .key == "debug_artifact_path" :
94
- artifact_path = spec .value .decode ()
95
- if spec .key == "output_format" :
96
- output_format = spec .value .decode ()
97
- if spec .key == "compile_flags" :
98
- compile_flags .append (spec .value .decode ())
99
- if spec .key == "dump_debug_info" :
100
- dump_debug_info = spec .value .decode ()
101
-
102
- # Check that the output format is set correctly in the compile spec
103
- if output_format != "tosa" :
104
- raise ValueError (f'Invalid output format { output_format } , must be "tosa"' )
94
+ artifact_path = compile_spec .get_intermediate_path ()
95
+ tosa_spec = compile_spec .tosa_spec
96
+ dump_debug_info = compile_spec .tosa_debug_mode
105
97
106
98
# Assign to every node external id
107
99
node_2_id = _annotate_external_ids (edge_program .graph )
108
100
109
- tosa_spec = get_tosa_spec (compile_spec )
110
- if tosa_spec is None :
111
- raise ValueError (
112
- "TOSA backend needs a TOSA version specified in the CompileSpec"
113
- )
114
-
115
101
logger .info (f"Converting ExportedProgram to TOSA: { tosa_spec } " )
116
102
117
103
# Converted output for this subgraph, serializer needs path early as it emits
@@ -132,7 +118,7 @@ def preprocess( # noqa: C901
132
118
133
119
debug_hook = None
134
120
if dump_debug_info is not None :
135
- debug_hook = DebugHook (ArmCompileSpec . DebugMode [ dump_debug_info ] )
121
+ debug_hook = DebugHook (dump_debug_info )
136
122
137
123
# TODO: Fix the need to lazily import this.
138
124
from executorch .backends .arm .operators .node_visitor import get_node_visitors
@@ -204,8 +190,8 @@ def _sort_key(t: Node) -> int:
204
190
205
191
@staticmethod
206
192
def filter_tosa_compile_specs (
207
- compile_spec : List [ CompileSpec ] ,
208
- ) -> List [ CompileSpec ] :
193
+ compile_spec : ArmCompileSpec ,
194
+ ) -> TosaCompileSpec :
209
195
"""
210
196
Filter out the CompileSpec elements relevant for the TOSA backend.
211
197
This is needed to compose a backend targetting hardware IP with the
@@ -214,17 +200,9 @@ def filter_tosa_compile_specs(
214
200
flatbuffer can then be consumed by the backend targetting specific
215
201
hardware.
216
202
"""
217
- tosa_compile_spec = []
218
- tosa_compile_spec .append (CompileSpec ("output_format" , "tosa" .encode ()))
219
-
220
- # Copy everything that's TOSA generic
221
- tosa_backend_compile_spec_keys = [
222
- "tosa_spec" ,
223
- "debug_artifact_path" ,
224
- ]
225
203
226
- for spec in compile_spec :
227
- if spec . key in tosa_backend_compile_spec_keys :
228
- tosa_compile_spec . append ( CompileSpec ( spec . key , spec . value ) )
229
-
230
- return tosa_compile_spec
204
+ new_compile_spec = TosaCompileSpec . __new__ ( TosaCompileSpec )
205
+ new_compile_spec . _set_compile_specs (
206
+ compile_spec . tosa_spec , [], compile_spec . get_intermediate_path ( )
207
+ )
208
+ return new_compile_spec
0 commit comments