33# Please refer to the license found in the LICENSE file in the root directory of the source tree.
44
55import argparse
6+ import collections
67import copy
78
89import pathlib
2324from executorch .exir import to_edge
2425
2526from executorch .exir .backend .backend_api import to_backend
26-
27- from torch .export import export
27+ from executorch .extension .export_util .utils import save_pte_program
2828
2929REPO_ROOT = pathlib .Path (__file__ ).resolve ().parent .parent .parent .parent .parent
3030EXAMPLES_DIR = REPO_ROOT / "examples"
4141)
4242
4343
44- def parse_args () -> argparse .ArgumentParser :
44+ def is_fbcode ():
45+ return not hasattr (torch .version , "git_version" )
46+
47+
48+ _CAN_RUN_WITH_PYBINDINGS = (sys .platform == "darwin" ) and not is_fbcode ()
49+ if _CAN_RUN_WITH_PYBINDINGS :
50+ from executorch .runtime import Runtime
51+
52+
53+ def parse_args () -> argparse .Namespace :
4554 parser = argparse .ArgumentParser ()
4655
4756 parser .add_argument (
@@ -82,9 +91,12 @@ def parse_args() -> argparse.ArgumentParser:
8291 required = False ,
8392 default = False ,
8493 )
94+ parser .add_argument (
95+ "--run_with_pybindings" ,
96+ action = argparse .BooleanOptionalAction ,
97+ )
8598
8699 args = parser .parse_args ()
87- # pyre-fixme[7]: Expected `ArgumentParser` but got `Namespace`.
88100 return args
89101
90102
@@ -95,7 +107,8 @@ def partition_module_to_coreml(module):
95107def lower_module_to_coreml (module , compile_specs , example_inputs ):
96108 module = module .eval ()
97109 edge = to_edge (
98- export (module , example_inputs , strict = True ), compile_config = _EDGE_COMPILE_CONFIG
110+ torch .export .export (module , example_inputs , strict = True ),
111+ compile_config = _EDGE_COMPILE_CONFIG ,
99112 )
100113 # All of the subsequent calls on the edge_dialect_graph generated above (such as delegation or
101114 # to_executorch()) are done in place and the graph is also modified in place. For debugging purposes
@@ -115,24 +128,23 @@ def lower_module_to_coreml(module, compile_specs, example_inputs):
115128def export_lowered_module_to_executorch_program (lowered_module , example_inputs ):
116129 lowered_module (* example_inputs )
117130 exec_prog = to_edge (
118- export (lowered_module , example_inputs , strict = True ),
131+ torch . export . export (lowered_module , example_inputs , strict = True ),
119132 compile_config = _EDGE_COMPILE_CONFIG ,
120133 ).to_executorch (config = exir .ExecutorchBackendConfig (extract_delegate_segments = True ))
121134
122135 return exec_prog
123136
124137
125- def save_executorch_program (exec_prog , model_name , compute_unit ):
126- buffer = exec_prog .buffer
127- filename = f"{ model_name } _coreml_{ compute_unit } .pte"
128- print (f"Saving exported program to { filename } " )
129- with open (filename , "wb" ) as file :
130- file .write (buffer )
131- return
138+ def get_pte_base_name (args : argparse .Namespace ) -> str :
139+ pte_name = args .model_name
140+ if args .compile :
141+ pte_name += "_compiled"
142+ pte_name = f"{ pte_name } _coreml_{ args .compute_unit } "
143+ return pte_name
132144
133145
134- def save_processed_bytes (processed_bytes , model_name , compute_unit ):
135- filename = f"{ model_name } _coreml_ { compute_unit } .bin"
146+ def save_processed_bytes (processed_bytes , base_name : str ):
147+ filename = f"{ base_name } .bin"
136148 print (f"Saving processed bytes to { filename } " )
137149 with open (filename , "wb" ) as file :
138150 file .write (processed_bytes )
@@ -154,6 +166,37 @@ def generate_compile_specs_from_args(args):
154166 )
155167
156168
169+ def run_with_pybindings (executorch_program , eager_reference , example_inputs , precision ):
170+ if not _CAN_RUN_WITH_PYBINDINGS :
171+ raise RuntimeError ("Cannot run with pybindings on this platform." )
172+
173+ dtype = {
174+ "float32" : torch .float32 ,
175+ "float16" : torch .float16 ,
176+ }[precision ]
177+
178+ runtime = Runtime .get ()
179+ program = runtime .load_program (executorch_program .buffer )
180+ method = program .load_method ("forward" )
181+ et_outputs = method .execute (* example_inputs )[0 ]
182+ eager_outputs = eager_reference (* example_inputs )
183+ if isinstance (eager_outputs , collections .OrderedDict ):
184+ eager_outputs = eager_outputs ["out" ]
185+ if isinstance (eager_outputs , list | tuple ):
186+ eager_outputs = eager_outputs [0 ]
187+
188+ mse = ((et_outputs - eager_outputs ) ** 2 ).mean ().sqrt ()
189+ print (f"Mean square error: { mse } " )
190+ assert mse < 0.1 , "Mean square error is too high."
191+
192+ if dtype == torch .float32 :
193+ assert torch .allclose (
194+ et_outputs , eager_outputs , atol = 1e-02 , rtol = 1e-02
195+ ), f"""Outputs do not match eager reference:
196+ \t et_outputs (first 5)={ et_outputs .reshape (- 1 )[0 :5 ]}
197+ \t eager_outputs (first 5)={ eager_outputs .reshape (- 1 )[0 :5 ]} """
198+
199+
157200def main ():
158201 args = parse_args ()
159202
@@ -170,49 +213,65 @@ def main():
170213 f"Valid compute units are { valid_compute_units } ."
171214 )
172215
173- model , example_inputs , _ , dynamic_shapes = EagerModelFactory . create_model (
174- * MODEL_NAME_TO_MODEL [args .model_name ]
216+ model , example_args , example_kwargs , dynamic_shapes = (
217+ EagerModelFactory . create_model ( * MODEL_NAME_TO_MODEL [args .model_name ])
175218 )
176219 if not args .dynamic_shapes :
177220 dynamic_shapes = None
178221
179222 compile_specs = generate_compile_specs_from_args (args )
180- lowered_module = None
181-
223+ pte_base_name = get_pte_base_name (args )
182224 if args .use_partitioner :
183- model .eval ()
184- exir_program_aten = torch .export .export (
185- model , example_inputs , dynamic_shapes = dynamic_shapes , strict = True
186- )
187-
188- edge_program_manager = exir .to_edge (exir_program_aten )
189- edge_copy = copy .deepcopy (edge_program_manager )
190- partitioner = CoreMLPartitioner (
191- skip_ops_for_coreml_delegation = None , compile_specs = compile_specs
225+ model = model .eval ()
226+ assert not args .generate_etrecord , "ETRecord is not supported with partitioner"
227+ ep = torch .export .export (
228+ model ,
229+ args = example_args ,
230+ kwargs = example_kwargs ,
231+ dynamic_shapes = dynamic_shapes ,
192232 )
193- delegated_program_manager = edge_program_manager .to_backend (partitioner )
194- exec_program = delegated_program_manager .to_executorch (
195- config = exir .ExecutorchBackendConfig (extract_delegate_segments = True )
233+ print (ep )
234+ delegated_program = exir .to_edge_transform_and_lower (
235+ ep ,
236+ partitioner = [CoreMLPartitioner (compile_specs = compile_specs )],
196237 )
238+ exec_program = delegated_program .to_executorch ()
239+ save_pte_program (exec_program , pte_base_name )
240+ if args .run_with_pybindings :
241+ run_with_pybindings (
242+ executorch_program = exec_program ,
243+ eager_reference = model ,
244+ example_inputs = example_args ,
245+ precision = args .compute_precision ,
246+ )
197247 else :
198248 lowered_module , edge_copy = lower_module_to_coreml (
199249 module = model ,
200- example_inputs = example_inputs ,
250+ example_inputs = example_args ,
201251 compile_specs = compile_specs ,
202252 )
203253 exec_program = export_lowered_module_to_executorch_program (
204254 lowered_module ,
205- example_inputs ,
206- )
207-
208- model_name = f"{ args .model_name } _compiled" if args .compile else args .model_name
209- save_executorch_program (exec_program , model_name , args .compute_unit )
210- generate_etrecord (f"{ args .model_name } _coreml_etrecord.bin" , edge_copy , exec_program )
211-
212- if args .save_processed_bytes and lowered_module is not None :
213- save_processed_bytes (
214- lowered_module .processed_bytes , args .model_name , args .compute_unit
255+ example_args ,
215256 )
257+ save_pte_program (exec_program , pte_base_name )
258+ if args .generate_etrecord :
259+ generate_etrecord (
260+ f"{ args .model_name } _coreml_etrecord.bin" , edge_copy , exec_program
261+ )
262+
263+ if args .save_processed_bytes :
264+ save_processed_bytes (
265+ lowered_module .processed_bytes ,
266+ pte_base_name ,
267+ )
268+ if args .run_with_pybindings :
269+ run_with_pybindings (
270+ executorch_program = exec_program ,
271+ eager_reference = model ,
272+ example_inputs = example_args ,
273+ precision = args .compute_precision ,
274+ )
216275
217276
218277if __name__ == "__main__" :
0 commit comments