Skip to content

Commit d49878a

Browse files
committed
refactor: simplify linear to functional linear conversion by removing AST code
- Remove complex AST-based code replacement (92 lines) - Use direct node.target modification instead - Tested with 50 samples: 100% success rate, ES(-6) = 0.993
1 parent 6c62d31 commit d49878a

File tree

1 file changed

+5
-92
lines changed

1 file changed

+5
-92
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 5 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,15 @@ def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
167167
Returns:
168168
Modified GraphModule object
169169
"""
170+
import torch.nn.functional as F
171+
170172
# Get reference to torch._C._nn.linear for comparison
171173
try:
172174
unstable_linear = torch._C._nn.linear
173175
except AttributeError:
174176
unstable_linear = None
175177

176178
# Traverse all nodes to find nodes that need to be replaced
177-
nodes_to_replace = []
178179
for node in gm.graph.nodes:
179180
if node.op == "call_function":
180181
target = node.target
@@ -203,98 +204,10 @@ def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
203204
should_replace = True
204205

205206
if should_replace:
206-
nodes_to_replace.append(node)
207-
208-
# Since torch._C._nn.linear and torch.nn.functional.linear are the same object,
209-
# the code generator cannot distinguish them, so we need to use AST to modify the code string after code generation
210-
if nodes_to_replace:
211-
# First recompile to generate code
212-
gm.recompile()
213-
214-
# Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
215-
code_str = gm.code
216-
217-
# Parse AST
218-
tree = ast.parse(code_str)
219-
220-
class LinearReplacer(ast.NodeTransformer):
221-
def visit_Call(self, node):
222-
# Check if it's a torch._C._nn.linear call
223-
# Structure: torch._C._nn.linear(...)
224-
filtered_nodes = [
225-
node
226-
for node in [node]
227-
if isinstance(node.func, ast.Attribute)
228-
if node.func.attr == "linear"
229-
if isinstance(node.func.value, ast.Attribute)
230-
if node.func.value.attr == "_nn"
231-
if isinstance(node.func.value.value, ast.Attribute)
232-
if node.func.value.value.attr == "_C"
233-
if isinstance(node.func.value.value.value, ast.Name)
234-
if node.func.value.value.value.id == "torch"
235-
]
236-
if filtered_nodes:
237-
# Found torch._C._nn.linear, replace with torch.nn.functional.linear
238-
new_func = ast.Attribute(
239-
value=ast.Attribute(
240-
value=ast.Attribute(
241-
value=ast.Name(
242-
id="torch",
243-
ctx=ast.Load(),
244-
),
245-
attr="nn",
246-
ctx=ast.Load(),
247-
),
248-
attr="functional",
249-
ctx=ast.Load(),
250-
),
251-
attr="linear",
252-
ctx=ast.Load(),
253-
)
254-
node.func = new_func
255-
return self.generic_visit(node)
256-
257-
transformer = LinearReplacer()
258-
modified_tree = transformer.visit(tree)
259-
ast.fix_missing_locations(modified_tree)
260-
261-
# Convert the modified AST back to code string
262-
new_code = ast.unparse(modified_tree)
263-
264-
# Recompile the modified code
265-
# Need to import device, inf and other modules that may be used
266-
namespace = {
267-
"torch": torch,
268-
}
269-
# Try to import device (if used in code)
270-
try:
271-
from torch import device
272-
273-
namespace["device"] = device
274-
except ImportError:
275-
pass
276-
# Try to import inf (if used in code)
277-
try:
278-
from torch import inf
207+
node.target = F.linear
279208

280-
namespace["inf"] = inf
281-
except ImportError:
282-
# If torch doesn't have inf, use math.inf
283-
try:
284-
import math
285-
286-
namespace["inf"] = math.inf
287-
except:
288-
pass
289-
290-
exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
291-
292-
# Update GraphModule's forward method
293-
forward_func = namespace.get("forward")
294-
if forward_func:
295-
import types
296-
297-
gm.forward = types.MethodType(forward_func, gm)
209+
# Recompile the graph
210+
gm.recompile()
298211

299212
return gm
300213

0 commit comments

Comments
 (0)