1010import tempfile
1111
1212from executorch .backends .cadence .aot .ops_registrations import * # noqa
13- import os
1413from typing import Any , Tuple
1514
1615from executorch .backends .cadence .aot .compiler import (
1716 convert_pt2 ,
1817 export_to_cadence ,
19- export_to_edge ,
20- quantize_pt2 ,
18+ fuse_pt2 ,
2119)
2220from executorch .backends .cadence .aot .quantizer .quantizer import CadenceQuantizer
2321from executorch .backends .cadence .runtime import runtime
2422from executorch .backends .cadence .runtime .executor import BundledProgramManager
2523from executorch .exir import ExecutorchProgramManager
2624from torch import nn
2725
28- from .utils import print_ops_info
26+ from .utils import save_bpte_program , save_pte_program
2927
3028
3129FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
3230logging .basicConfig (level = logging .INFO , format = FORMAT )
3331
3432
def _save_pte_program(
    prog: "ExecutorchProgramManager", model_name: str, output_dir: str = ""
) -> None:
    """Write an ExecuTorch program to a ``.pte`` file on a best-effort basis.

    Args:
        prog: Program manager whose ``write_to_file`` method is called with
            the opened binary file object.
        model_name: Either a bare model name, to which ``.pte`` is appended
            and which is placed under ``output_dir``, or a path that already
            ends in ``.pte``, which is used verbatim.
        output_dir: Directory for the generated file. Ignored when
            ``model_name`` already ends in ``.pte``.

    Any exception raised while opening or writing the file is logged and
    swallowed, so a failed save does not abort the calling export flow.
    """
    if model_name.endswith(".pte"):
        # The caller supplied a full file name; do not prepend output_dir.
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.pte")

    try:
        with open(filename, "wb") as file:
            prog.write_to_file(file)
            logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        # Deliberately best-effort: report the failure and carry on.
        logging.error(f"Error while saving to {filename}: {e}")
49-
50-
def _save_bpte_program(
    buffer: bytes,
    model_name: str,
    output_dir: str = "",
) -> None:
    """Write a serialized bundled program to a ``.bpte`` file (best-effort).

    Args:
        buffer: Raw bytes written to the file unchanged.
        model_name: Either a bare model name, to which ``.bpte`` is appended
            and which is placed under ``output_dir``, or a path that already
            ends in ``.bpte``, which is used verbatim.
        output_dir: Directory for the generated file. Ignored when
            ``model_name`` already ends in ``.bpte``.

    Any exception raised while opening or writing the file is logged and
    swallowed, so a failed save does not abort the calling export flow.
    """
    if model_name.endswith(".bpte"):
        # The caller supplied a full file name; do not prepend output_dir.
        filename = model_name
    else:
        filename = os.path.join(output_dir, f"{model_name}.bpte")

    try:
        with open(filename, "wb") as f:
            f.write(buffer)
            logging.info(f"Saved exported program to {filename}")
    except Exception as e:
        # Report the path that was actually opened: output_dir is not the
        # target when model_name already carries the .bpte suffix.
        logging.error(f"Error while saving to {filename}: {e}")
66-
67-
6833def export_model (
6934 model : nn .Module ,
7035 example_inputs : Tuple [Any , ...],
@@ -74,32 +39,28 @@ def export_model(
7439 working_dir = tempfile .mkdtemp (dir = "/tmp" )
7540 logging .debug (f"Created work directory { working_dir } " )
7641
77- # convert the model (also called in quantize_pt2)
78- converted_model = convert_pt2 ( model , example_inputs , CadenceQuantizer () )
42+ # Instantiate the quantizer
43+ quantizer = CadenceQuantizer ()
7944
80- # Get reference outputs from quantized_model
81- ref_outputs = converted_model ( * example_inputs )
45+ # Convert the model
46+ converted_model = convert_pt2 ( model , example_inputs , quantizer )
8247
83- # Quantize the model
84- quantized_model = quantize_pt2 ( model , example_inputs )
48+ # Get reference outputs from converted model
49+ ref_outputs = converted_model ( * example_inputs )
8550
86- # Get edge program (also called in export_to_cadence)
87- edge_prog_manager = export_to_edge (quantized_model , example_inputs )
51+ # Quantize the model (note: quantizer needs to be the same as
52+ # the one used in convert_pt2)
53+ quantized_model = fuse_pt2 (converted_model , quantizer )
8854
8955 # Get edge program after Cadence specific passes
9056 cadence_prog_manager = export_to_cadence (quantized_model , example_inputs )
9157
58+ # Get executorch program after Cadence specific passes
9259 exec_prog : ExecutorchProgramManager = cadence_prog_manager .to_executorch ()
9360
9461 logging .info ("Final exported graph:\n " )
9562 exec_prog .exported_program ().graph_module .graph .print_tabular ()
9663
97- # Print some information to terminal
98- print_ops_info (
99- edge_prog_manager .exported_program ().graph_module ,
100- cadence_prog_manager .exported_program ().graph_module ,
101- )
102-
10364 forward_test_data = BundledProgramManager .bundled_program_test_data_gen (
10465 method = "forward" , inputs = example_inputs , expected_outputs = ref_outputs
10566 )
@@ -110,9 +71,9 @@ def export_model(
11071 forward_test_data ,
11172 )
11273 # Save the program as pte (default name is CadenceDemoModel.pte)
113- _save_pte_program (exec_prog , file_name , working_dir )
74+ save_pte_program (exec_prog , file_name , working_dir )
11375 # Save the program as bpte (default name is CadenceDemoModel.bpte)
115- _save_bpte_program (buffer , file_name , working_dir )
76+ save_bpte_program (buffer , file_name , working_dir )
11677
11778 logging .debug (
11879 f"Executorch bundled program buffer saved to { file_name } is { len (buffer )} total bytes"
0 commit comments