Skip to content

Commit 1a77a0e

Browse files
committed
feat: convert torch._C._nn.linear to torch.nn.functional.linear
- Implement direct node.target modification for API conversion - Use serialize_graph_module_to_str for API check - Add AST-based replacement function (commented) in fx_graph_serialize_util.py - Simplify conversion logic by removing complex AST code - Tested with 50 samples: 100% success rate, ES(-6) = 1.013
1 parent 2116c66 commit 1a77a0e

File tree

2 files changed

+120
-103
lines changed

2 files changed

+120
-103
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 5 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,15 @@ def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
167167
Returns:
168168
Modified GraphModule object
169169
"""
170+
import torch.nn.functional as F
171+
170172
# Get reference to torch._C._nn.linear for comparison
171173
try:
172174
unstable_linear = torch._C._nn.linear
173175
except AttributeError:
174176
unstable_linear = None
175177

176178
# Traverse all nodes to find nodes that need to be replaced
177-
nodes_to_replace = []
178179
for node in gm.graph.nodes:
179180
if node.op == "call_function":
180181
target = node.target
@@ -203,108 +204,10 @@ def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
203204
should_replace = True
204205

205206
if should_replace:
206-
nodes_to_replace.append(node)
207-
208-
# Since torch._C._nn.linear and torch.nn.functional.linear are the same object,
209-
# the code generator cannot distinguish them, so we need to use AST to modify the code string after code generation
210-
if nodes_to_replace:
211-
# First recompile to generate code
212-
gm.recompile()
213-
214-
# Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
215-
code_str = gm.code
216-
217-
# Parse AST
218-
tree = ast.parse(code_str)
219-
220-
class LinearReplacer(ast.NodeTransformer):
221-
def visit_Call(self, node):
222-
# Check if it's a torch._C._nn.linear call
223-
# Structure: torch._C._nn.linear(...)
224-
if isinstance(node.func, ast.Attribute):
225-
# node.func.attr should be "linear"
226-
if node.func.attr == "linear":
227-
# node.func.value should be torch._C._nn
228-
if isinstance(node.func.value, ast.Attribute):
229-
# node.func.value.attr should be "_nn"
230-
if node.func.value.attr == "_nn":
231-
# node.func.value.value should be torch._C
232-
if isinstance(node.func.value.value, ast.Attribute):
233-
# node.func.value.value.attr should be "_C"
234-
if node.func.value.value.attr == "_C":
235-
# node.func.value.value.value should be torch
236-
if (
237-
isinstance(
238-
node.func.value.value.value,
239-
ast.Name,
240-
)
241-
and node.func.value.value.value.id
242-
== "torch"
243-
):
244-
# Found torch._C._nn.linear, replace with torch.nn.functional.linear
245-
new_func = ast.Attribute(
246-
value=ast.Attribute(
247-
value=ast.Attribute(
248-
value=ast.Name(
249-
id="torch",
250-
ctx=ast.Load(),
251-
),
252-
attr="nn",
253-
ctx=ast.Load(),
254-
),
255-
attr="functional",
256-
ctx=ast.Load(),
257-
),
258-
attr="linear",
259-
ctx=ast.Load(),
260-
)
261-
node.func = new_func
262-
return self.generic_visit(node)
263-
264-
transformer = LinearReplacer()
265-
modified_tree = transformer.visit(tree)
266-
ast.fix_missing_locations(modified_tree)
267-
268-
# Convert the modified AST back to code string
269-
new_code = ast.unparse(modified_tree)
270-
271-
# Recompile the modified code
272-
# Need to import device, inf and other modules that may be used
273-
namespace = {
274-
"torch": torch,
275-
}
276-
# Try to import device (if used in code)
277-
try:
278-
from torch import device
279-
280-
namespace["device"] = device
281-
except ImportError:
282-
pass
283-
# Try to import inf (if used in code)
284-
try:
285-
from torch import inf
286-
287-
namespace["inf"] = inf
288-
except ImportError:
289-
# If torch doesn't have inf, use math.inf
290-
try:
291-
import math
207+
node.target = F.linear
292208

293-
namespace["inf"] = math.inf
294-
except:
295-
pass
296-
297-
exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
298-
299-
# Update GraphModule's forward method and code
300-
forward_func = namespace.get("forward")
301-
if forward_func:
302-
import types
303-
304-
gm.forward = types.MethodType(forward_func, gm)
305-
306-
# Update _code attribute so that gm.code returns the modified code
307-
gm._code = new_code
209+
# Recompile the graph
210+
gm.recompile()
308211

309212
return gm
310213

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,120 @@
22
import torch.fx
33

44

5+
# def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
6+
# """
7+
# Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
8+
#
9+
# This function uses AST parsing and transformation to replace torch._C._nn.linear
10+
# calls with torch.nn.functional.linear in the GraphModule's code.
11+
#
12+
# Note: This function is currently commented out as the replacement is now handled
13+
# by simple string replacement in serialize_graph_module_to_str.
14+
#
15+
# Args:
16+
# gm: The GraphModule to modify.
17+
#
18+
# Returns:
19+
# Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
20+
# """
21+
# import ast
22+
# import torch
23+
# import types
24+
#
25+
# # First recompile to generate code
26+
# gm.recompile()
27+
#
28+
# # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
29+
# code_str = gm.code
30+
#
31+
# # Parse AST
32+
# tree = ast.parse(code_str)
33+
#
34+
# class LinearReplacer(ast.NodeTransformer):
35+
# def visit_Call(self, node):
36+
# # Check if it's a torch._C._nn.linear call
37+
# # Structure: torch._C._nn.linear(...)
38+
# filtered_nodes = [
39+
# node
40+
# for node in [node]
41+
# if isinstance(node.func, ast.Attribute)
42+
# if node.func.attr == "linear"
43+
# if isinstance(node.func.value, ast.Attribute)
44+
# if node.func.value.attr == "_nn"
45+
# if isinstance(node.func.value.value, ast.Attribute)
46+
# if node.func.value.value.attr == "_C"
47+
# if isinstance(node.func.value.value.value, ast.Name)
48+
# if node.func.value.value.value.id == "torch"
49+
# ]
50+
# if filtered_nodes:
51+
# # Found torch._C._nn.linear, replace with torch.nn.functional.linear
52+
# new_func = ast.Attribute(
53+
# value=ast.Attribute(
54+
# value=ast.Attribute(
55+
# value=ast.Name(
56+
# id="torch",
57+
# ctx=ast.Load(),
58+
# ),
59+
# attr="nn",
60+
# ctx=ast.Load(),
61+
# ),
62+
# attr="functional",
63+
# ctx=ast.Load(),
64+
# ),
65+
# attr="linear",
66+
# ctx=ast.Load(),
67+
# )
68+
# node.func = new_func
69+
# return self.generic_visit(node)
70+
#
71+
# transformer = LinearReplacer()
72+
# modified_tree = transformer.visit(tree)
73+
# ast.fix_missing_locations(modified_tree)
74+
#
75+
# # Convert the modified AST back to code string
76+
# new_code = ast.unparse(modified_tree)
77+
#
78+
# # Recompile the modified code
79+
# # Need to import device, inf and other modules that may be used
80+
# namespace = {
81+
# "torch": torch,
82+
# }
83+
# # Try to import device (if used in code)
84+
# try:
85+
# from torch import device
86+
#
87+
# namespace["device"] = device
88+
# except ImportError:
89+
# pass
90+
# # Try to import inf (if used in code)
91+
# try:
92+
# from torch import inf
93+
#
94+
# namespace["inf"] = inf
95+
# except ImportError:
96+
# # If torch doesn't have inf, use math.inf
97+
# try:
98+
# import math
99+
#
100+
# namespace["inf"] = math.inf
101+
# except:
102+
# pass
103+
#
104+
# exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
105+
#
106+
# # Update GraphModule's forward method
107+
# forward_func = namespace.get("forward")
108+
# if forward_func:
109+
# gm.forward = types.MethodType(forward_func, gm)
110+
#
111+
# # Use serialize_graph_module_to_str to get the serialized code
112+
# # This ensures the code is properly serialized with unstable API replacements
113+
# serialized_code = serialize_graph_module_to_str(gm)
114+
# gm._code = serialized_code
115+
#
116+
# return gm
117+
118+
5119
def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
6120
"""
7121
Serialize a GraphModule to a string representation, replacing unstable APIs
@@ -34,7 +148,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
34148
# replace this line with modification code for task 123 (torch._C._nn.pad)
35149
# replace this line with modification code for task 125 (torch._C._nn.gelu)
36150
# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
37-
# replace this line with modification code for task 127 (torch._C._nn.linear)
151+
(r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
38152
]
39153
for pattern, repl in replacements:
40154
code = re.sub(pattern, repl, code)

0 commit comments

Comments
 (0)