1515import executorch .exir as exir
1616import torch
1717from executorch .exir import ExecutorchBackendConfig , ExecutorchProgramManager , to_edge
18+ from executorch .exir .capture ._capture import patch_forward
1819from executorch .exir .dynamic_shape import DynamicMemoryPlanningMode
1920from executorch .exir .passes import (
2021 DebugPass ,
@@ -70,6 +71,7 @@ def export(
7071 export_joint_graph : bool = False ,
7172 external_constants : bool = False ,
7273 export_state_names : bool = False ,
74+ share_mutable_buffers : bool = False ,
7375 ) -> "ExportedModule" :
7476 """
7577 Creates a new ExportedModule for the specified module class.
@@ -134,10 +136,13 @@ def return_wrapper():
134136 # all exported methods must have the same signature so just pick the first one.
135137 methods [0 ],
136138 )
137- trace_inputs : Sequence = get_trace_inputs ()
139+ inputs = get_trace_inputs ()
138140 method_name_to_args = {}
139141 for method in methods :
140- method_name_to_args [method ] = trace_inputs
142+ if hasattr (eager_module , "get_random_inputs_per_method" ):
143+ # pyre-ignore
144+ inputs = eager_module .get_random_inputs_per_method ()[method ] # type: ignore[operator]
145+ method_name_to_args [method ] = inputs
141146
142147 method_name_to_dynamic_shapes = None
143148 if hasattr (eager_module , "get_dynamic_shapes" ):
@@ -149,23 +154,18 @@ def return_wrapper():
149154 method_name_to_dynamic_shapes [method ] = trace_dynamic_shapes
150155
151156 memory_planning_pass = MemoryPlanningPass (
152- alloc_mutable_buffers = not export_state_names
157+ alloc_mutable_buffers = not export_state_names ,
158+ share_mutable_buffers = share_mutable_buffers ,
153159 )
154160 if hasattr (eager_module , "get_memory_planning_pass" ):
155161 memory_planning_pass = eager_module .get_memory_planning_pass () # type: ignore[operator]
156162
157- class WrapperModule (nn .Module ):
158- def __init__ (self , method ):
159- super ().__init__ ()
160- self .forward = method
161-
162163 exported_methods = {}
163164 # These cleanup passes are required to convert the `add` op to its out
164165 # variant, along with some other transformations.
165166 for method_name , method_input in method_name_to_args .items ():
166167 # if not isinstance(eager_module, torch.nn.Module):
167168 if export_joint_graph :
168- # _export was having issues with WrapperModule.
169169 assert method_name == "forward"
170170 ep = _export (
171171 eager_module ,
@@ -179,15 +179,16 @@ def __init__(self, method):
179179 )
180180 exported_methods [method_name ] = _export_forward_backward (ep )
181181 else :
182- exported_methods [method_name ] = export (
183- eager_module ,
184- method_input , # type: ignore[arg-type]
185- dynamic_shapes = (
186- method_name_to_dynamic_shapes [method_name ]
187- if method_name_to_dynamic_shapes
188- else None
189- ),
190- )
182+ with patch_forward (eager_module , getattr (eager_module , method_name )):
183+ exported_methods [method_name ] = export (
184+ eager_module ,
185+ method_input , # type: ignore[arg-type]
186+ dynamic_shapes = (
187+ method_name_to_dynamic_shapes [method_name ]
188+ if method_name_to_dynamic_shapes
189+ else None
190+ ),
191+ )
191192
192193 exec_prog = to_edge (
193194 exported_methods ,
@@ -229,6 +230,6 @@ def __init__(self, method):
229230 methods = methods ,
230231 executorch_program = exec_prog ,
231232 exported_program = exported_program ,
232- trace_inputs = trace_inputs ,
233+ trace_inputs = inputs ,
233234 get_random_inputs_fn = get_random_inputs_fn ,
234235 )
0 commit comments