99
1010import argparse
1111import logging
12+ import os
1213
1314import torch
1415
@@ -48,6 +49,19 @@ def get_model_and_inputs_from_name(model_name: str):
4849 model , example_inputs , _ = EagerModelFactory .create_model (
4950 * MODEL_NAME_TO_MODEL [model_name ]
5051 )
52+ # Case 3: Model is in an external python file loaded as a module.
53+ # ModelUnderTest should be a torch.nn.module instance
54+ # ModelInputs should be a tuple of inputs to the forward function
55+ elif model_name .endswith (".py" ):
56+ import importlib .util
57+
58+ # load model's module and add it
59+ spec = importlib .util .spec_from_file_location ("tmp_model" , model_name )
60+ module = importlib .util .module_from_spec (spec )
61+ spec .loader .exec_module (module )
62+ model = module .ModelUnderTest
63+ example_inputs = module .ModelInputs
64+
5165 else :
5266 raise RuntimeError (
5367 f"Model '{ model_name } ' is not a valid name. Use --help for a list of available models."
@@ -133,7 +147,51 @@ def forward(self, x):
133147 "softmax" : SoftmaxModule ,
134148}
135149
136- if __name__ == "__main__" :
150+ targets = [
151+ "ethos-u85-128" ,
152+ "ethos-u55-128" ,
153+ "TOSA" ,
154+ ]
155+
156+
157+ def get_compile_spec (target : str , intermediates : bool ) -> ArmCompileSpecBuilder :
158+ spec_builder = None
159+ if target == "TOSA" :
160+ spec_builder = (
161+ ArmCompileSpecBuilder ().tosa_compile_spec ().set_permute_memory_format (True )
162+ )
163+ elif target == "ethos-u55-128" :
164+ spec_builder = (
165+ ArmCompileSpecBuilder ()
166+ .ethosu_compile_spec (
167+ "ethos-u55-128" ,
168+ system_config = "Ethos_U55_High_End_Embedded" ,
169+ memory_mode = "Shared_Sram" ,
170+ extra_flags = "--debug-force-regor --output-format=raw" ,
171+ )
172+ .set_permute_memory_format (args .model_name in MODEL_NAME_TO_MODEL .keys ())
173+ .set_quantize_io (True )
174+ )
175+ elif target == "ethos-u85-128" :
176+ spec_builder = (
177+ ArmCompileSpecBuilder ()
178+ .ethosu_compile_spec (
179+ "ethos-u85-128" ,
180+ system_config = "Ethos_U85_SYS_DRAM_Mid" ,
181+ memory_mode = "Shared_Sram" ,
182+ extra_flags = "--output-format=raw" ,
183+ )
184+ .set_permute_memory_format (True )
185+ .set_quantize_io (True )
186+ )
187+
188+ if intermediates is not None :
189+ spec_builder .dump_intermediate_artifacts_to (args .intermediates )
190+
191+ return spec_builder .build ()
192+
193+
194+ def get_args ():
137195 parser = argparse .ArgumentParser ()
138196 parser .add_argument (
139197 "-m" ,
@@ -149,6 +207,15 @@ def forward(self, x):
149207 default = False ,
150208 help = "Flag for producing ArmBackend delegated model" ,
151209 )
210+ parser .add_argument (
211+ "-t" ,
212+ "--target" ,
213+ action = "store" ,
214+ required = False ,
215+ default = "ethos-u55-128" ,
216+ choices = targets ,
217+ help = f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are { targets } " ,
218+ )
152219 parser .add_argument (
153220 "-q" ,
154221 "--quantize" ,
@@ -167,8 +234,26 @@ def forward(self, x):
167234 parser .add_argument (
168235 "--debug" , action = "store_true" , help = "Set the logging level to debug."
169236 )
170-
237+ parser .add_argument (
238+ "-i" ,
239+ "--intermediates" ,
240+ action = "store" ,
241+ required = False ,
242+ help = "Store intermediate output (like TOSA artefacts) somewhere." ,
243+ )
244+ parser .add_argument (
245+ "-o" ,
246+ "--output" ,
247+ action = "store" ,
248+ required = False ,
249+ help = "Location for outputs, if not the default of cwd." ,
250+ )
171251 args = parser .parse_args ()
252+ return args
253+
254+
255+ if __name__ == "__main__" :
256+ args = get_args ()
172257
173258 if args .debug :
174259 logging .basicConfig (level = logging .DEBUG , format = FORMAT , force = True )
@@ -191,7 +276,7 @@ def forward(self, x):
191276 ):
192277 raise RuntimeError (f"Model { args .model_name } cannot be delegated." )
193278
194- # 1. pick model from one of the supported lists
279+ # Pick model from one of the supported lists
195280 model , example_inputs = get_model_and_inputs_from_name (args .model_name )
196281 model = model .eval ()
197282
@@ -209,23 +294,18 @@ def forward(self, x):
209294 _check_ir_validity = False ,
210295 ),
211296 )
297+
298+ # As we can target multiple output encodings from ArmBackend, one must
299+ # be specified.
300+ compile_spec = (
301+ get_compile_spec (args .target , args .intermediates )
302+ if args .delegate is True
303+ else None
304+ )
305+
212306 logging .debug (f"Exported graph:\n { edge .exported_program ().graph } " )
213307 if args .delegate is True :
214- edge = edge .to_backend (
215- ArmPartitioner (
216- ArmCompileSpecBuilder ()
217- .ethosu_compile_spec (
218- "ethos-u55-128" ,
219- system_config = "Ethos_U55_High_End_Embedded" ,
220- memory_mode = "Shared_Sram" ,
221- )
222- .set_permute_memory_format (
223- args .model_name in MODEL_NAME_TO_MODEL .keys ()
224- )
225- .set_quantize_io (True )
226- .build ()
227- )
228- )
308+ edge = edge .to_backend (ArmPartitioner (compile_spec ))
229309 logging .debug (f"Lowered graph:\n { edge .exported_program ().graph } " )
230310
231311 try :
@@ -241,7 +321,12 @@ def forward(self, x):
241321 else :
242322 raise e
243323
244- model_name = f"{ args .model_name } " + (
245- "_arm_delegate" if args .delegate is True else ""
324+ model_name = os .path .basename (os .path .splitext (args .model_name )[0 ])
325+ output_name = f"{ model_name } " + (
326+ f"_arm_delegate_{ args .target } " if args .delegate is True else ""
246327 )
247- save_pte_program (exec_prog , model_name )
328+
329+ if args .output is not None :
330+ output_name = os .path .join (args .output , output_name )
331+
332+ save_pte_program (exec_prog , output_name )
0 commit comments