1515from typing import Any , Tuple
1616
1717from executorch .backends .cadence .aot .compiler import (
18+ _lower_ep_to_cadence_gen_etrecord ,
1819 convert_pt2 ,
19- export_to_executorch_gen_etrecord ,
2020 fuse_pt2 ,
2121 prepare_pt2 ,
2222)
@@ -38,13 +38,12 @@ def export_model(
3838 model : nn .Module ,
3939 example_inputs : Tuple [Any , ...],
4040 file_name : str = "CadenceDemoModel" ,
41- run_and_compare : bool = True ,
42- eps_error : float = 1e-1 ,
43- eps_warn : float = 1e-5 ,
41+ working_dir : Optional [str ] = None ,
4442):
4543 # create work directory for outputs and model binary
46- working_dir = tempfile .mkdtemp (dir = "/tmp" )
47- logging .debug (f"Created work directory { working_dir } " )
44+ if working_dir is None :
45+ working_dir = tempfile .mkdtemp (dir = "/tmp" )
46+ logging .debug (f"Created work directory { working_dir } " )
4847
4948 # Instantiate the quantizer
5049 quantizer = CadenceDefaultQuantizer ()
@@ -66,9 +65,11 @@ def export_model(
6665 # the one used in prepare_and_convert_pt2)
6766 quantized_model = fuse_pt2 (converted_model , quantizer )
6867
68+ ep = torch .export .export (quantized_model , example_inputs , strict = True )
69+
6970 # Get edge program after Cadence specific passes
70- exec_prog : ExecutorchProgramManager = export_to_executorch_gen_etrecord (
71- quantized_model , example_inputs , output_dir = working_dir
71+ exec_prog : ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord (
72+ ep , output_dir = working_dir
7273 )
7374
7475 logging .info ("Final exported graph:\n " )
@@ -92,13 +93,24 @@ def export_model(
9293 f"Executorch bundled program buffer saved to { file_name } is { len (buffer )} total bytes"
9394 )
9495
95- # TODO: move to test infra
96- if run_and_compare :
97- runtime .run_and_compare (
98- executorch_prog = exec_prog ,
99- inputs = example_inputs ,
100- ref_outputs = ref_outputs ,
101- working_dir = working_dir ,
102- eps_error = eps_error ,
103- eps_warn = eps_warn ,
104- )
96+
97+ def export_and_run_model (
98+ model : nn .Module ,
99+ example_inputs : Tuple [Any , ...],
100+ file_name : str = "CadenceDemoModel" ,
101+ eps_error : float = 1e-1 ,
102+ eps_warn : float = 1e-5 ,
103+ ):
104+ # create work directory for outputs and model binary
105+ working_dir = tempfile .mkdtemp (dir = "/tmp" )
106+ logging .debug (f"Created work directory { working_dir } " )
107+ exec_prog = export_model (model , example_inputs , file_name , working_dir )
108+ ref_outputs = model (* example_inputs )
109+ runtime .run_and_compare (
110+ executorch_prog = exec_prog ,
111+ inputs = example_inputs ,
112+ ref_outputs = ref_outputs ,
113+ working_dir = working_dir ,
114+ eps_error = eps_error ,
115+ eps_warn = eps_warn ,
116+ )
0 commit comments