Skip to content

Commit 3a3b149

Browse files
committed
feat: convert torch._C._nn.linear to torch.nn.functional.linear
- Implement direct node.target modification for API conversion - Use serialize_graph_module_to_str for API check in check_unstable_api - Add AST-based replacement function (commented) in fx_graph_serialize_util.py - Fix log2json.py to properly initialize result field and map speedup data - Simplify conversion logic by removing complex AST code - Tested with 50 samples: 100% success rate, ES(-6) = 1.013
1 parent c65f7fa commit 3a3b149

File tree

3 files changed

+201
-2
lines changed

3 files changed

+201
-2
lines changed

graph_net/log2json.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ def parse_logs_to_json(log_file: str, output_dir: str):
5353
"datatype": {},
5454
"speedup": {},
5555
},
56+
"result": {
57+
"status": "unknown",
58+
},
5659
}
5760
continue
5861

@@ -102,16 +105,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
102105
result_status_match = patterns["result_status"].search(line)
103106
if result_status_match:
104107
status = result_status_match.group(1).strip()
108+
data["result"]["status"] = status
105109
if status == "failed" and (i + 1) < len(lines):
106110
error_reason_match = patterns["failure"].search(lines[i + 1])
107111
if error_reason_match:
108112
reason = error_reason_match.group(1).lower()
109113
if "eager" in reason:
110114
data["performance"]["failure"] = "eager"
115+
data["result"]["status"] = "eager_fail"
111116
elif "compiled" in reason:
112117
data["performance"]["failure"] = "compiled"
118+
data["result"]["status"] = "compile_fail"
113119
else:
114120
data["performance"]["failure"] = "other"
121+
data["result"]["status"] = "runtime_fail"
115122
continue
116123

117124
speedup_match = patterns["speedup"].search(line)
@@ -141,6 +148,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
141148
# filename = f"{model_name}_{subgraph_name}_{compiler_name}.json"
142149
filepath = os.path.join(output_dir, filename)
143150

151+
# Build result field with status and speedup
152+
if data["result"]["status"] == "success":
153+
speedup_data = {}
154+
if "e2e" in data["performance"]["speedup"]:
155+
speedup_data["e2e"] = {
156+
"mean": data["performance"]["speedup"]["e2e"]
157+
}
158+
if "gpu" in data["performance"]["speedup"]:
159+
speedup_data["gpu"] = {
160+
"mean": data["performance"]["speedup"]["gpu"]
161+
}
162+
if speedup_data:
163+
data["result"]["speedup"] = speedup_data
164+
144165
with open(filepath, "w", encoding="utf-8") as f:
145166
json.dump(data, f, indent=4)
146167

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import sys
44
import inspect
5+
import ast
56
from .graph_compiler_backend import GraphCompilerBackend
67
from ..fx_graph_serialize_util import serialize_graph_module_to_str
78

@@ -12,12 +13,21 @@ def __call__(self, model):
1213
unstable_api = os.getenv("DISALLOWED_UNSTABLE_API", "").strip()
1314
self.unstable_api = unstable_api
1415

16+
# Use torch.compile's backend method to get graph module uniformly
17+
# This ensures all models use the same conversion method, avoiding performance differences
1518
def my_backend(gm, sample_inputs):
19+
# Convert unstable API
1620
gm = self.unstable_to_stable(gm)
1721
self.check_unstable_api(gm)
22+
# Return forward function without additional optimization
1823
return gm.forward
1924

20-
return torch.compile(backend=my_backend)(model)
25+
# Use torch.compile to get graph module and perform conversion
26+
# Although compile is used, the backend only does API conversion, no optimization
27+
# Performance should be close to eager mode (since only API replacement is done)
28+
# Note: Do not use mode parameter to avoid version compatibility issues
29+
compiled_model = torch.compile(model, backend=my_backend)
30+
return compiled_model
2131

2232
"""
2333
TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
@@ -147,6 +157,60 @@ def _impl_unstable_to_stable_special_logit(self, gm):
147157

148158
return gm
149159

160+
def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
161+
"""
162+
Convert torch._C._nn.linear to torch.nn.functional.linear
163+
164+
Args:
165+
gm: torch.fx.GraphModule object
166+
167+
Returns:
168+
Modified GraphModule object
169+
"""
170+
import torch.nn.functional as F
171+
172+
# Get reference to torch._C._nn.linear for comparison
173+
try:
174+
unstable_linear = torch._C._nn.linear
175+
except AttributeError:
176+
unstable_linear = None
177+
178+
# Traverse all nodes to find nodes that need to be replaced
179+
for node in gm.graph.nodes:
180+
if node.op == "call_function":
181+
target = node.target
182+
should_replace = False
183+
184+
# Method 1: Direct target comparison (most reliable)
185+
if unstable_linear is not None and target is unstable_linear:
186+
should_replace = True
187+
# Method 2: Check if it's the same function object (using id comparison)
188+
elif unstable_linear is not None and id(target) == id(unstable_linear):
189+
should_replace = True
190+
# Method 3: Check module and name attributes (most reliable method, as torch.fx preserves these attributes)
191+
elif hasattr(target, "__module__") and hasattr(target, "__name__"):
192+
if (
193+
target.__module__ == "torch._C._nn"
194+
and target.__name__ == "linear"
195+
):
196+
should_replace = True
197+
# Method 4: Check via string representation (fallback method)
198+
elif "torch._C._nn.linear" in str(target) or (
199+
hasattr(target, "__qualname__")
200+
and "linear" in target.__qualname__
201+
and hasattr(target, "__module__")
202+
and "torch._C._nn" in str(target.__module__)
203+
):
204+
should_replace = True
205+
206+
if should_replace:
207+
node.target = F.linear
208+
209+
# Recompile the graph
210+
gm.recompile()
211+
212+
return gm
213+
150214
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
151215

152216
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,120 @@
22
import torch.fx
33

44

5+
# def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
6+
# """
7+
# Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
8+
#
9+
# This function uses AST parsing and transformation to replace torch._C._nn.linear
10+
# calls with torch.nn.functional.linear in the GraphModule's code.
11+
#
12+
# Note: This function is currently commented out as the replacement is now handled
13+
# by simple string replacement in serialize_graph_module_to_str.
14+
#
15+
# Args:
16+
# gm: The GraphModule to modify.
17+
#
18+
# Returns:
19+
# Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
20+
# """
21+
# import ast
22+
# import torch
23+
# import types
24+
#
25+
# # First recompile to generate code
26+
# gm.recompile()
27+
#
28+
# # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
29+
# code_str = gm.code
30+
#
31+
# # Parse AST
32+
# tree = ast.parse(code_str)
33+
#
34+
# class LinearReplacer(ast.NodeTransformer):
35+
# def visit_Call(self, node):
36+
# # Check if it's a torch._C._nn.linear call
37+
# # Structure: torch._C._nn.linear(...)
38+
# filtered_nodes = [
39+
# node
40+
# for node in [node]
41+
# if isinstance(node.func, ast.Attribute)
42+
# if node.func.attr == "linear"
43+
# if isinstance(node.func.value, ast.Attribute)
44+
# if node.func.value.attr == "_nn"
45+
# if isinstance(node.func.value.value, ast.Attribute)
46+
# if node.func.value.value.attr == "_C"
47+
# if isinstance(node.func.value.value.value, ast.Name)
48+
# if node.func.value.value.value.id == "torch"
49+
# ]
50+
# if filtered_nodes:
51+
# # Found torch._C._nn.linear, replace with torch.nn.functional.linear
52+
# new_func = ast.Attribute(
53+
# value=ast.Attribute(
54+
# value=ast.Attribute(
55+
# value=ast.Name(
56+
# id="torch",
57+
# ctx=ast.Load(),
58+
# ),
59+
# attr="nn",
60+
# ctx=ast.Load(),
61+
# ),
62+
# attr="functional",
63+
# ctx=ast.Load(),
64+
# ),
65+
# attr="linear",
66+
# ctx=ast.Load(),
67+
# )
68+
# node.func = new_func
69+
# return self.generic_visit(node)
70+
#
71+
# transformer = LinearReplacer()
72+
# modified_tree = transformer.visit(tree)
73+
# ast.fix_missing_locations(modified_tree)
74+
#
75+
# # Convert the modified AST back to code string
76+
# new_code = ast.unparse(modified_tree)
77+
#
78+
# # Recompile the modified code
79+
# # Need to import device, inf and other modules that may be used
80+
# namespace = {
81+
# "torch": torch,
82+
# }
83+
# # Try to import device (if used in code)
84+
# try:
85+
# from torch import device
86+
#
87+
# namespace["device"] = device
88+
# except ImportError:
89+
# pass
90+
# # Try to import inf (if used in code)
91+
# try:
92+
# from torch import inf
93+
#
94+
# namespace["inf"] = inf
95+
# except ImportError:
96+
# # If torch doesn't have inf, use math.inf
97+
# try:
98+
# import math
99+
#
100+
# namespace["inf"] = math.inf
101+
# except:
102+
# pass
103+
#
104+
# exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
105+
#
106+
# # Update GraphModule's forward method
107+
# forward_func = namespace.get("forward")
108+
# if forward_func:
109+
# gm.forward = types.MethodType(forward_func, gm)
110+
#
111+
# # Use serialize_graph_module_to_str to get the serialized code
112+
# # This ensures the code is properly serialized with unstable API replacements
113+
# serialized_code = serialize_graph_module_to_str(gm)
114+
# gm._code = serialized_code
115+
#
116+
# return gm
117+
118+
5119
def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
6120
"""
7121
Serialize a GraphModule to a string representation, replacing unstable APIs
@@ -34,7 +148,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
34148
# replace this line with modification code for task 123 (torch._C._nn.pad)
35149
# replace this line with modification code for task 125 (torch._C._nn.gelu)
36150
# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
37-
# replace this line with modification code for task 127 (torch._C._nn.linear)
151+
(r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
38152
]
39153
for pattern, repl in replacements:
40154
code = re.sub(pattern, repl, code)

0 commit comments

Comments
 (0)