@@ -167,14 +167,15 @@ def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
167167 Returns:
168168 Modified GraphModule object
169169 """
170+ import torch .nn .functional as F
171+
170172 # Get reference to torch._C._nn.linear for comparison
171173 try :
172174 unstable_linear = torch ._C ._nn .linear
173175 except AttributeError :
174176 unstable_linear = None
175177
176178 # Traverse all nodes to find nodes that need to be replaced
177- nodes_to_replace = []
178179 for node in gm .graph .nodes :
179180 if node .op == "call_function" :
180181 target = node .target
@@ -203,98 +204,10 @@ def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
203204 should_replace = True
204205
205206 if should_replace :
206- nodes_to_replace .append (node )
207-
208- # Since torch._C._nn.linear and torch.nn.functional.linear are the same object,
209- # the code generator cannot distinguish them, so we need to use AST to modify the code string after code generation
210- if nodes_to_replace :
211- # First recompile to generate code
212- gm .recompile ()
213-
214- # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
215- code_str = gm .code
216-
217- # Parse AST
218- tree = ast .parse (code_str )
219-
220- class LinearReplacer (ast .NodeTransformer ):
221- def visit_Call (self , node ):
222- # Check if it's a torch._C._nn.linear call
223- # Structure: torch._C._nn.linear(...)
224- filtered_nodes = [
225- node
226- for node in [node ]
227- if isinstance (node .func , ast .Attribute )
228- if node .func .attr == "linear"
229- if isinstance (node .func .value , ast .Attribute )
230- if node .func .value .attr == "_nn"
231- if isinstance (node .func .value .value , ast .Attribute )
232- if node .func .value .value .attr == "_C"
233- if isinstance (node .func .value .value .value , ast .Name )
234- if node .func .value .value .value .id == "torch"
235- ]
236- if filtered_nodes :
237- # Found torch._C._nn.linear, replace with torch.nn.functional.linear
238- new_func = ast .Attribute (
239- value = ast .Attribute (
240- value = ast .Attribute (
241- value = ast .Name (
242- id = "torch" ,
243- ctx = ast .Load (),
244- ),
245- attr = "nn" ,
246- ctx = ast .Load (),
247- ),
248- attr = "functional" ,
249- ctx = ast .Load (),
250- ),
251- attr = "linear" ,
252- ctx = ast .Load (),
253- )
254- node .func = new_func
255- return self .generic_visit (node )
256-
257- transformer = LinearReplacer ()
258- modified_tree = transformer .visit (tree )
259- ast .fix_missing_locations (modified_tree )
260-
261- # Convert the modified AST back to code string
262- new_code = ast .unparse (modified_tree )
263-
264- # Recompile the modified code
265- # Need to import device, inf and other modules that may be used
266- namespace = {
267- "torch" : torch ,
268- }
269- # Try to import device (if used in code)
270- try :
271- from torch import device
272-
273- namespace ["device" ] = device
274- except ImportError :
275- pass
276- # Try to import inf (if used in code)
277- try :
278- from torch import inf
207+ node .target = F .linear
279208
280- namespace ["inf" ] = inf
281- except ImportError :
282- # If torch doesn't have inf, use math.inf
283- try :
284- import math
285-
286- namespace ["inf" ] = math .inf
287- except :
288- pass
289-
290- exec (compile (modified_tree , filename = "<ast>" , mode = "exec" ), namespace )
291-
292- # Update GraphModule's forward method
293- forward_func = namespace .get ("forward" )
294- if forward_func :
295- import types
296-
297- gm .forward = types .MethodType (forward_func , gm )
209+ # Recompile the graph
210+ gm .recompile ()
298211
299212 return gm
300213
0 commit comments