44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ # pyre-unsafe
8+
79import argparse
810import inspect
911import os
1921from executorch .exir .backend .test .backend_with_compiler_demo import (
2022 BackendWithCompilerDemo ,
2123)
24+ from executorch .exir .program import ExecutorchProgramManager
2225from torch import nn
2326from torch .export import export
2427
@@ -111,10 +114,10 @@ def export_module_to_program(
111114 * ,
112115 backend_id : str ,
113116 extract_delegate_segments : bool ,
114- constant_tensor_alignemnt : Optional [int ] = None ,
117+ constant_tensor_alignment : Optional [int ] = None ,
115118 delegate_alignment : Optional [int ] = None ,
116119 method : str = "forward" ,
117- ) -> bytes :
120+ ) -> ExecutorchProgramManager :
118121 eager_module = module_class ().eval ()
119122 inputs = ()
120123 if hasattr (eager_module , "get_random_inputs" ):
@@ -135,7 +138,7 @@ def forward(self, *args, **kwargs):
135138 edge_config = EdgeCompileConfig (_check_ir_validity = False )
136139 et_config = exir .ExecutorchBackendConfig (
137140 extract_delegate_segments = extract_delegate_segments ,
138- constant_tensor_alignment = constant_tensor_alignemnt ,
141+ constant_tensor_alignment = constant_tensor_alignment ,
139142 delegate_alignment = delegate_alignment ,
140143 )
141144
@@ -170,7 +173,7 @@ def forward(self, *args, **kwargs):
170173 export (composite_module , args = inputs , strict = True )
171174 ).to_executorch (config = et_config )
172175
173- return executorch_program . buffer
176+ return executorch_program
174177
175178
176179def main () -> None :
@@ -199,6 +202,14 @@ def main() -> None:
199202 help = "ID of the backend to use for delegation; "
200203 + f"one of { known_backend_ids } " ,
201204 )
205+ parser .add_argument (
206+ "--inline_delegate_segments" ,
207+ action = "store_true" ,
208+ help = "Store delegate data inside the flatbuffer." ,
209+ )
210+ parser .add_argument (
211+ "--delegate_alignment" , type = int , default = None , help = "Delegate alignment."
212+ )
202213 parser .add_argument (
203214 "--outdir" ,
204215 type = str ,
@@ -219,25 +230,22 @@ def main() -> None:
219230
220231 # Export and write to the output files.
221232 os .makedirs (args .outdir , exist_ok = True )
233+ suffix = ""
222234 for module_name , module_class in module_names_to_classes .items ():
223- for extract_delegate_segments in (True , False ):
224- suffix = "" if extract_delegate_segments else "-nosegments"
225- # Create files with the default alignment, and a large alignment.
226- # This alignment should be so large that it's extremely unlikely for
227- # the data to accidentally be aligned to it in the default case.
228- for delegate_alignment in (None , 1024 ):
229- suffix += f"-da{ delegate_alignment } " if delegate_alignment else ""
230- outfile = os .path .join (args .outdir , f"{ module_name } { suffix } .pte" )
231- with open (outfile , "wb" ) as fp :
232- fp .write (
233- export_module_to_program (
234- module_class ,
235- backend_id = args .backend_id ,
236- extract_delegate_segments = extract_delegate_segments ,
237- delegate_alignment = delegate_alignment ,
238- )
239- )
240- print (f"Exported { module_name } and wrote program data to { outfile } " )
235+ if args .inline_delegate_segments :
236+ suffix += "-nosegments"
237+ if args .delegate_alignment is not None :
238+ suffix += f"-da{ args .delegate_alignment } "
239+ outfile = os .path .join (args .outdir , f"{ module_name } { suffix } .pte" )
240+ executorch_program = export_module_to_program (
241+ module_class ,
242+ backend_id = args .backend_id ,
243+ extract_delegate_segments = not args .inline_delegate_segments ,
244+ delegate_alignment = args .delegate_alignment ,
245+ )
246+ with open (outfile , "wb" ) as fp :
247+ fp .write (executorch_program .buffer )
248+ print (f"Exported { module_name } and wrote program data to { outfile } " )
241249
242250
243251if __name__ == "__main__" :
0 commit comments