@@ -90,6 +90,18 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
9090 return (torch .ones (n , n , n ), 2 * torch .ones (n , n , n ), 3 * torch .ones (n , n , n ))
9191
9292
93+ class ModuleLinear (torch .nn .Module ):
94+ def __init__ (self ):
95+ super ().__init__ ()
96+ self .linear = torch .nn .Linear (3 , 3 )
97+
98+ def forward (self , x : torch .Tensor ):
99+ return self .linear (x )
100+
101+ def get_random_inputs (self ):
102+ return (torch .randn (3 ),)
103+
104+
93105#
94106# Backends
95107#
@@ -116,24 +128,23 @@ def export_module_to_program(
116128 extract_delegate_segments : bool ,
117129 constant_tensor_alignment : Optional [int ] = None ,
118130 delegate_alignment : Optional [int ] = None ,
119- method : str = "forward" ,
131+ method_name : str = "forward" ,
120132) -> ExecutorchProgramManager :
121133 eager_module = module_class ().eval ()
122134 inputs = ()
123135 if hasattr (eager_module , "get_random_inputs" ):
124136 inputs = eager_module .get_random_inputs () # type: ignore[operator]
125137
126138 class WrapperModule (torch .nn .Module ):
127- def __init__ (self , fn ):
139+ def __init__ (self , fn , method_name = method_name ):
128140 super ().__init__ ()
129141 self .fn = fn
142+ self .method_name = method_name
130143
131144 def forward (self , * args , ** kwargs ):
132- return self .fn (* args , ** kwargs )
145+ return getattr ( self .fn , self . method_name ) (* args , ** kwargs )
133146
134- exported_program = export (
135- WrapperModule (getattr (eager_module , method )), args = inputs , strict = True
136- )
147+ exported_program = export (WrapperModule (eager_module ), args = inputs , strict = True )
137148
138149 edge_config = EdgeCompileConfig (_check_ir_validity = False )
139150 et_config = exir .ExecutorchBackendConfig (
0 commit comments