Skip to content

Commit d3060c0

Browse files
committed
refactor: use serialize_graph_module_to_str for API check and add AST-based replacement function (commented)
1 parent ed19d5d commit d3060c0

File tree

2 files changed

+117
-3
lines changed

2 files changed

+117
-3
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ def check_unstable_api(self, gm):
344344
Do NOT modify, remove, or bypass this check under any circumstances.
345345
"""
346346

347-
# Use code to check for unstable APIs
348-
graph_text = gm.code
347+
# Use serialized code to check for unstable APIs
348+
graph_text = serialize_graph_module_to_str(gm)
349349
# Search for the unstable API substring
350350
if self.unstable_api in graph_text:
351351
count = graph_text.count(self.unstable_api)

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,120 @@
22
import torch.fx
33

44

5+
# def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
6+
# """
7+
# Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
8+
#
9+
# This function uses AST parsing and transformation to replace torch._C._nn.linear
10+
# calls with torch.nn.functional.linear in the GraphModule's code.
11+
#
12+
# Note: This function is currently commented out as the replacement is now handled
13+
# by simple string replacement in serialize_graph_module_to_str.
14+
#
15+
# Args:
16+
# gm: The GraphModule to modify.
17+
#
18+
# Returns:
19+
# Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
20+
# """
21+
# import ast
22+
# import torch
23+
# import types
24+
#
25+
# # First recompile to generate code
26+
# gm.recompile()
27+
#
28+
# # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
29+
# code_str = gm.code
30+
#
31+
# # Parse AST
32+
# tree = ast.parse(code_str)
33+
#
34+
# class LinearReplacer(ast.NodeTransformer):
35+
# def visit_Call(self, node):
36+
# # Check if it's a torch._C._nn.linear call
37+
# # Structure: torch._C._nn.linear(...)
38+
# filtered_nodes = [
39+
# node
40+
# for node in [node]
41+
# if isinstance(node.func, ast.Attribute)
42+
# if node.func.attr == "linear"
43+
# if isinstance(node.func.value, ast.Attribute)
44+
# if node.func.value.attr == "_nn"
45+
# if isinstance(node.func.value.value, ast.Attribute)
46+
# if node.func.value.value.attr == "_C"
47+
# if isinstance(node.func.value.value.value, ast.Name)
48+
# if node.func.value.value.value.id == "torch"
49+
# ]
50+
# if filtered_nodes:
51+
# # Found torch._C._nn.linear, replace with torch.nn.functional.linear
52+
# new_func = ast.Attribute(
53+
# value=ast.Attribute(
54+
# value=ast.Attribute(
55+
# value=ast.Name(
56+
# id="torch",
57+
# ctx=ast.Load(),
58+
# ),
59+
# attr="nn",
60+
# ctx=ast.Load(),
61+
# ),
62+
# attr="functional",
63+
# ctx=ast.Load(),
64+
# ),
65+
# attr="linear",
66+
# ctx=ast.Load(),
67+
# )
68+
# node.func = new_func
69+
# return self.generic_visit(node)
70+
#
71+
# transformer = LinearReplacer()
72+
# modified_tree = transformer.visit(tree)
73+
# ast.fix_missing_locations(modified_tree)
74+
#
75+
# # Convert the modified AST back to code string
76+
# new_code = ast.unparse(modified_tree)
77+
#
78+
# # Recompile the modified code
79+
# # Need to import device, inf and other modules that may be used
80+
# namespace = {
81+
# "torch": torch,
82+
# }
83+
# # Try to import device (if used in code)
84+
# try:
85+
# from torch import device
86+
#
87+
# namespace["device"] = device
88+
# except ImportError:
89+
# pass
90+
# # Try to import inf (if used in code)
91+
# try:
92+
# from torch import inf
93+
#
94+
# namespace["inf"] = inf
95+
# except ImportError:
96+
# # If torch doesn't have inf, use math.inf
97+
# try:
98+
# import math
99+
#
100+
# namespace["inf"] = math.inf
101+
# except:
102+
# pass
103+
#
104+
# exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
105+
#
106+
# # Update GraphModule's forward method
107+
# forward_func = namespace.get("forward")
108+
# if forward_func:
109+
# gm.forward = types.MethodType(forward_func, gm)
110+
#
111+
# # Use serialize_graph_module_to_str to get the serialized code
112+
# # This ensures the code is properly serialized with unstable API replacements
113+
# serialized_code = serialize_graph_module_to_str(gm)
114+
# gm._code = serialized_code
115+
#
116+
# return gm
117+
118+
5119
def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
6120
"""
7121
Serialize a GraphModule to a string representation, replacing unstable APIs
@@ -25,7 +139,6 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
25139
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
26140
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27141
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
28-
(r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
29142
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
30143
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
31144
# replace this line with modification code for task 118 (torch._C._nn.softplus)
@@ -35,6 +148,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
35148
# replace this line with modification code for task 123 (torch._C._nn.pad)
36149
# replace this line with modification code for task 125 (torch._C._nn.gelu)
37150
# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
151+
(r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
38152
]
39153
for pattern, repl in replacements:
40154
code = re.sub(pattern, repl, code)

0 commit comments

Comments
 (0)