1111import logging
1212import tempfile
1313
14+ import torch
15+
1416from executorch .backends .cadence .aot .ops_registrations import * # noqa
15- from typing import Any , Tuple
17+ from typing import Any , Optional , Tuple
1618
1719from executorch .backends .cadence .aot .compiler import (
20+ _lower_ep_to_cadence_gen_etrecord ,
1821 convert_pt2 ,
19- export_to_executorch_gen_etrecord ,
2022 fuse_pt2 ,
2123 prepare_pt2 ,
2224)
@@ -38,13 +40,12 @@ def export_model(
3840 model : nn .Module ,
3941 example_inputs : Tuple [Any , ...],
4042 file_name : str = "CadenceDemoModel" ,
41- run_and_compare : bool = True ,
42- eps_error : float = 1e-1 ,
43- eps_warn : float = 1e-5 ,
43+ working_dir : Optional [str ] = None ,
4444):
4545 # create work directory for outputs and model binary
46- working_dir = tempfile .mkdtemp (dir = "/tmp" )
47- logging .debug (f"Created work directory { working_dir } " )
46+ if working_dir is None :
47+ working_dir = tempfile .mkdtemp (dir = "/tmp" )
48+ logging .debug (f"Created work directory { working_dir } " )
4849
4950 # Instantiate the quantizer
5051 quantizer = CadenceDefaultQuantizer ()
@@ -66,9 +67,11 @@ def export_model(
6667 # the one used in prepare_and_convert_pt2)
6768 quantized_model = fuse_pt2 (converted_model , quantizer )
6869
70+ ep = torch .export .export (quantized_model , example_inputs , strict = True )
71+
6972 # Get edge program after Cadence specific passes
70- exec_prog : ExecutorchProgramManager = export_to_executorch_gen_etrecord (
71- quantized_model , example_inputs , output_dir = working_dir
73+ exec_prog : ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord (
74+ ep , output_dir = working_dir
7275 )
7376
7477 logging .info ("Final exported graph:\n " )
@@ -92,13 +95,24 @@ def export_model(
9295 f"Executorch bundled program buffer saved to { file_name } is { len (buffer )} total bytes"
9396 )
9497
95- # TODO: move to test infra
96- if run_and_compare :
97- runtime .run_and_compare (
98- executorch_prog = exec_prog ,
99- inputs = example_inputs ,
100- ref_outputs = ref_outputs ,
101- working_dir = working_dir ,
102- eps_error = eps_error ,
103- eps_warn = eps_warn ,
104- )
98+
99+ def export_and_run_model (
100+ model : nn .Module ,
101+ example_inputs : Tuple [Any , ...],
102+ file_name : str = "CadenceDemoModel" ,
103+ eps_error : float = 1e-1 ,
104+ eps_warn : float = 1e-5 ,
105+ ):
106+ # create work directory for outputs and model binary
107+ working_dir = tempfile .mkdtemp (dir = "/tmp" )
108+ logging .debug (f"Created work directory { working_dir } " )
109+ exec_prog = export_model (model , example_inputs , file_name , working_dir )
110+ ref_outputs = model (* example_inputs )
111+ runtime .run_and_compare (
112+ executorch_prog = exec_prog ,
113+ inputs = example_inputs ,
114+ ref_outputs = ref_outputs ,
115+ working_dir = working_dir ,
116+ eps_error = eps_error ,
117+ eps_warn = eps_warn ,
118+ )
0 commit comments