# Example script for exporting simple models to flatbuffer
99
1010import argparse
11+ import json
1112import logging
1213import os
13- from typing import Optional
1414
15- import torch
15+ from pathlib import Path
16+ from typing import Optional , Tuple
1617
18+ import torch
1719from executorch .backends .arm .arm_backend import ArmCompileSpecBuilder
1820from executorch .backends .arm .arm_partitioner import ArmPartitioner
1921from executorch .backends .arm .quantizer .arm_quantizer import (
2022 ArmQuantizer ,
2123 get_symmetric_quantization_config ,
2224)
25+ from executorch .backends .arm .util .arm_model_evaluator import GenericModelEvaluator
2326
2427from executorch .devtools .backend_debug import get_delegation_info
2528from executorch .exir import EdgeCompileConfig , ExecutorchBackendConfig
@@ -151,6 +154,8 @@ def forward(self, x):
151154 "softmax" : SoftmaxModule ,
152155}
153156
157+ evaluators = {}
158+
154159targets = [
155160 "ethos-u55-32" ,
156161 "ethos-u55-64" ,
@@ -202,6 +207,37 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
202207 return spec_builder .build ()
203208
204209
def get_evaluator(model_name: str) -> GenericModelEvaluator:
    """Return the evaluator class registered for *model_name*.

    Falls back to ``GenericModelEvaluator`` when no model-specific
    evaluator is present in the module-level ``evaluators`` table.
    """
    return evaluators.get(model_name, GenericModelEvaluator)
215+
216+
def evaluate_model(
    model_name: str,
    intermediates: str,
    model_fp32: torch.nn.Module,
    model_int8: torch.nn.Module,
    example_inputs: Tuple[torch.Tensor],
):
    """Evaluate quantization quality and dump the metrics as JSON.

    Args:
        model_name: Key used to pick a model-specific evaluator (falls
            back to the generic one).
        intermediates: Directory where the TOSA flatbuffer was dumped
            during compilation.
        model_fp32: The original floating-point model.
        model_int8: The quantized model to compare against it.
        example_inputs: Example input tensors driving the evaluation.

    Raises:
        RuntimeError: If no ``*.tosa`` file exists in *intermediates*.
    """
    evaluator = get_evaluator(model_name)

    # Get the path of the TOSA flatbuffer that is dumped
    intermediates_path = Path(intermediates)
    tosa_paths = list(intermediates_path.glob("*.tosa"))
    if not tosa_paths:
        # Fail with a clear message instead of an opaque IndexError on
        # tosa_paths[0] when the dump directory is wrong or empty.
        raise RuntimeError(
            f"No .tosa file found in {intermediates_path}; make sure "
            "--intermediates points at the directory used during export."
        )

    # Evaluators are constructed with the first dumped flatbuffer; multiple
    # .tosa files would indicate multiple delegated partitions — TODO confirm
    # whether only the first is intended in that case.
    init_evaluator = evaluator(
        model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
    )

    quant_metrics = init_evaluator.evaluate()
    output_json_path = intermediates_path / "quant_metrics.json"

    with output_json_path.open("w") as json_file:
        json.dump(quant_metrics, json_file)
239+
240+
205241def dump_delegation_info (edge , intermediate_files_folder : Optional [str ] = None ):
206242 graph_module = edge .exported_program ().graph_module
207243 delegation_info = get_delegation_info (graph_module )
@@ -242,6 +278,14 @@ def get_args():
242278 choices = targets ,
243279 help = f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are { targets } " ,
244280 )
281+ parser .add_argument (
282+ "-e" ,
283+ "--evaluate" ,
284+ action = "store_true" ,
285+ required = False ,
286+ default = False ,
287+ help = "Flag for running evaluation of the model." ,
288+ )
245289 parser .add_argument (
246290 "-q" ,
247291 "--quantize" ,
@@ -275,11 +319,11 @@ def get_args():
275319 help = "Location for outputs, if not the default of cwd." ,
276320 )
277321 args = parser .parse_args ()
278- return args
279322
280-
281- if __name__ == "__main__" :
282- args = get_args ()
323+ if args .evaluate and (args .quantize is None or args .intermediates is None ):
324+ raise RuntimeError (
325+ "--evaluate requires --quantize and --intermediates to be enabled."
326+ )
283327
284328 if args .debug :
285329 logging .basicConfig (level = logging .DEBUG , format = FORMAT , force = True )
@@ -302,16 +346,26 @@ def get_args():
302346 ):
303347 raise RuntimeError (f"Model { args .model_name } cannot be delegated." )
304348
349+ return args
350+
351+
352+ if __name__ == "__main__" :
353+ args = get_args ()
354+
305355 # Pick model from one of the supported lists
306356 model , example_inputs = get_model_and_inputs_from_name (args .model_name )
307357 model = model .eval ()
308358
359+ model_fp32 = model
360+
309361 # pre-autograd export. eventually this will become torch.export
310362 model = torch .export .export_for_training (model , example_inputs ).module ()
311363
312364 # Quantize if required
365+ model_int8 = None
313366 if args .quantize :
314367 model = quantize (model , example_inputs )
368+ model_int8 = model
315369
316370 edge = export_to_edge (
317371 model ,
@@ -361,3 +415,8 @@ def get_args():
361415 output_name = os .path .join (args .output , output_name )
362416
363417 save_pte_program (exec_prog , output_name )
418+
419+ if args .evaluate :
420+ evaluate_model (
421+ args .model_name , args .intermediates , model_fp32 , model_int8 , example_inputs
422+ )
0 commit comments