1010
1111import logging
1212import tempfile
13+ import torch
1314
1415from executorch .backends .cadence .aot .ops_registrations import * # noqa
15- from typing import Any , Tuple
16+ from typing import Any , Optional , Tuple
1617
1718from executorch .backends .cadence .aot .compiler import (
19+ _lower_ep_to_cadence_gen_etrecord ,
1820 convert_pt2 ,
19- export_to_executorch_gen_etrecord ,
2021 fuse_pt2 ,
2122 prepare_pt2 ,
2223)
@@ -38,13 +39,12 @@ def export_model(
3839 model : nn .Module ,
3940 example_inputs : Tuple [Any , ...],
4041 file_name : str = "CadenceDemoModel" ,
41- run_and_compare : bool = True ,
42- eps_error : float = 1e-1 ,
43- eps_warn : float = 1e-5 ,
42+ working_dir : Optional [str ] = None ,
4443):
4544 # create work directory for outputs and model binary
46- working_dir = tempfile .mkdtemp (dir = "/tmp" )
47- logging .debug (f"Created work directory { working_dir } " )
45+ if working_dir is None :
46+ working_dir = tempfile .mkdtemp (dir = "/tmp" )
47+ logging .debug (f"Created work directory { working_dir } " )
4848
4949 # Instantiate the quantizer
5050 quantizer = CadenceDefaultQuantizer ()
@@ -66,9 +66,11 @@ def export_model(
6666 # the one used in prepare_and_convert_pt2)
6767 quantized_model = fuse_pt2 (converted_model , quantizer )
6868
69+ ep = torch .export .export (quantized_model , example_inputs , strict = True )
70+
6971 # Get edge program after Cadence specific passes
70- exec_prog : ExecutorchProgramManager = export_to_executorch_gen_etrecord (
71- quantized_model , example_inputs , output_dir = working_dir
72+ exec_prog : ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord (
73+ ep , output_dir = working_dir
7274 )
7375
7476 logging .info ("Final exported graph:\n " )
@@ -92,13 +94,24 @@ def export_model(
9294 f"Executorch bundled program buffer saved to { file_name } is { len (buffer )} total bytes"
9395 )
9496
95- # TODO: move to test infra
96- if run_and_compare :
97- runtime .run_and_compare (
98- executorch_prog = exec_prog ,
99- inputs = example_inputs ,
100- ref_outputs = ref_outputs ,
101- working_dir = working_dir ,
102- eps_error = eps_error ,
103- eps_warn = eps_warn ,
104- )
97+
98+ def export_and_run_model (
99+ model : nn .Module ,
100+ example_inputs : Tuple [Any , ...],
101+ file_name : str = "CadenceDemoModel" ,
102+ eps_error : float = 1e-1 ,
103+ eps_warn : float = 1e-5 ,
104+ ):
105+ # create work directory for outputs and model binary
106+ working_dir = tempfile .mkdtemp (dir = "/tmp" )
107+ logging .debug (f"Created work directory { working_dir } " )
108+ exec_prog = export_model (model , example_inputs , file_name , working_dir )
109+ ref_outputs = model (* example_inputs )
110+ runtime .run_and_compare (
111+ executorch_prog = exec_prog ,
112+ inputs = example_inputs ,
113+ ref_outputs = ref_outputs ,
114+ working_dir = working_dir ,
115+ eps_error = eps_error ,
116+ eps_warn = eps_warn ,
117+ )
0 commit comments