
Commit 4e56dbb

merge code

2 parents: 3c1c196 + 7a0717a
File tree

3 files changed: +221 −6 lines

graph_net/log2json.py

Lines changed: 21 additions & 0 deletions
@@ -53,6 +53,9 @@ def parse_logs_to_json(log_file: str, output_dir: str):
                 "datatype": {},
                 "speedup": {},
             },
+            "result": {
+                "status": "unknown",
+            },
         }
         continue

@@ -102,16 +105,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
         result_status_match = patterns["result_status"].search(line)
         if result_status_match:
             status = result_status_match.group(1).strip()
+            data["result"]["status"] = status
             if status == "failed" and (i + 1) < len(lines):
                 error_reason_match = patterns["failure"].search(lines[i + 1])
                 if error_reason_match:
                     reason = error_reason_match.group(1).lower()
                     if "eager" in reason:
                         data["performance"]["failure"] = "eager"
+                        data["result"]["status"] = "eager_fail"
                     elif "compiled" in reason:
                         data["performance"]["failure"] = "compiled"
+                        data["result"]["status"] = "compile_fail"
                     else:
                         data["performance"]["failure"] = "other"
+                        data["result"]["status"] = "runtime_fail"
             continue

         speedup_match = patterns["speedup"].search(line)
@@ -141,6 +148,20 @@ def parse_logs_to_json(log_file: str, output_dir: str):
         # filename = f"{model_name}_{subgraph_name}_{compiler_name}.json"
         filepath = os.path.join(output_dir, filename)

+        # Build result field with status and speedup
+        if data["result"]["status"] == "success":
+            speedup_data = {}
+            if "e2e" in data["performance"]["speedup"]:
+                speedup_data["e2e"] = {
+                    "mean": data["performance"]["speedup"]["e2e"]
+                }
+            if "gpu" in data["performance"]["speedup"]:
+                speedup_data["gpu"] = {
+                    "mean": data["performance"]["speedup"]["gpu"]
+                }
+            if speedup_data:
+                data["result"]["speedup"] = speedup_data
+
         with open(filepath, "w", encoding="utf-8") as f:
             json.dump(data, f, indent=4)
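
For orientation, a sketch of one parsed record as it might be written after this change. The numeric values and presence of both "e2e" and "gpu" entries are invented for illustration; only the field layout follows the code above.

{
    "performance": {
        "datatype": {},
        "speedup": {"e2e": 1.18, "gpu": 1.25}
    },
    "result": {
        "status": "success",
        "speedup": {
            "e2e": {"mean": 1.18},
            "gpu": {"mean": 1.25}
        }
    }
}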

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 84 additions & 4 deletions
@@ -2,6 +2,7 @@
 import torch
 import sys
 import inspect
+import ast
 from .graph_compiler_backend import GraphCompilerBackend
 from ..fx_graph_serialize_util import serialize_graph_module_to_str

@@ -12,12 +13,21 @@ def __call__(self, model):
         unstable_api = os.getenv("DISALLOWED_UNSTABLE_API", "").strip()
         self.unstable_api = unstable_api

+        # Use torch.compile's backend mechanism to obtain the graph module uniformly.
+        # This ensures every model goes through the same conversion path, avoiding performance differences.
         def my_backend(gm, sample_inputs):
+            # Convert unstable APIs
            gm = self.unstable_to_stable(gm)
            self.check_unstable_api(gm)
+            # Return the forward function without any additional optimization
            return gm.forward

-        return torch.compile(backend=my_backend)(model)
+        # Use torch.compile to obtain the graph module and perform the conversion.
+        # Although compile is used, the backend only does API conversion, no optimization,
+        # so performance should stay close to eager mode (only API replacement is performed).
+        # Note: do not pass the mode parameter, to avoid version compatibility issues.
+        compiled_model = torch.compile(model, backend=my_backend)
+        return compiled_model

 """
 TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
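
As a standalone illustration of the pattern used above (this sketch is not part of the commit; the model and backend names are made up), a custom torch.compile backend receives the captured fx.GraphModule plus sample inputs, and returning gm.forward hands execution back to the unoptimized Python graph:

import torch

def passthrough_backend(gm: torch.fx.GraphModule, sample_inputs):
    # A real backend would rewrite gm.graph here (e.g. swap node targets)
    # and recompile; this sketch only lists the captured nodes.
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm.forward  # no optimization, so performance stays close to eager

model = torch.nn.Linear(4, 2)
compiled = torch.compile(model, backend=passthrough_backend)
out = compiled(torch.randn(3, 4))
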
@@ -144,14 +154,32 @@ def _impl_unstable_to_stable_special_logit(self, gm):

         # Recompile the graph
         gm.recompile()
-
         return gm

     # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)

     # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)

-    # replace this line with modification code for task 118 (torch._C._nn.softplus)
+    def _impl_unstable_to_stable_softplus(self, gm):
+        """
+        Convert torch._C._nn.softplus to torch.nn.functional.softplus
+        """
+        import torch.nn.functional as F
+
+        issue_nodes = (
+            node
+            for node in gm.graph.nodes
+            if node.op == "call_function"
+            if hasattr(node.target, "__module__")
+            if node.target.__module__ == "torch._C._nn"
+            if hasattr(node.target, "__name__")
+            if node.target.__name__ == "softplus"
+        )
+        for node in issue_nodes:
+            node.target = F.softplus
+
+        gm.recompile()
+        return gm

     def _impl_unstable_to_stable_one_hot(self, gm):
         """
@@ -186,7 +214,59 @@ def _impl_unstable_to_stable_one_hot(self, gm):

     # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)

-    # replace this line with modification code for task 127 (torch._C._nn.linear)
+    def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
+        """
+        Convert torch._C._nn.linear to torch.nn.functional.linear
+
+        Args:
+            gm: torch.fx.GraphModule object
+
+        Returns:
+            Modified GraphModule object
+        """
+        import torch.nn.functional as F
+
+        # Get a reference to torch._C._nn.linear for comparison
+        try:
+            unstable_linear = torch._C._nn.linear
+        except AttributeError:
+            unstable_linear = None
+
+        # Traverse all nodes to find those that need to be replaced
+        for node in gm.graph.nodes:
+            if node.op == "call_function":
+                target = node.target
+                should_replace = False
+
+                # Method 1: direct identity comparison against the target
+                if unstable_linear is not None and target is unstable_linear:
+                    should_replace = True
+                # Method 2: same function object, compared via id()
+                elif unstable_linear is not None and id(target) == id(unstable_linear):
+                    should_replace = True
+                # Method 3: check the module and name attributes, which torch.fx preserves
+                elif hasattr(target, "__module__") and hasattr(target, "__name__"):
+                    if (
+                        target.__module__ == "torch._C._nn"
+                        and target.__name__ == "linear"
+                    ):
+                        should_replace = True
+                # Method 4: check via string representation (fallback)
+                elif "torch._C._nn.linear" in str(target) or (
+                    hasattr(target, "__qualname__")
+                    and "linear" in target.__qualname__
+                    and hasattr(target, "__module__")
+                    and "torch._C._nn" in str(target.__module__)
+                ):
+                    should_replace = True
+
+                if should_replace:
+                    node.target = F.linear
+
+        # Recompile the graph
+        gm.recompile()
+
+        return gm

     def unstable_to_stable(self, gm):
         methods = (
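
If one wants to sanity-check that the public API is a drop-in for the private symbol (this snippet is illustrative, not part of the commit), torch.nn.functional.linear computes x @ W.T + b:

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)
w = torch.randn(2, 4)
b = torch.randn(2)
# F.linear applies y = x @ w.T + b, the same affine map the replaced node computes.
assert torch.allclose(F.linear(x, w, b), x @ w.t() + b, atol=1e-6)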

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 116 additions & 2 deletions
@@ -2,6 +2,120 @@
 import torch.fx


+# def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+#     """
+#     Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
+#
+#     This function uses AST parsing and transformation to replace torch._C._nn.linear
+#     calls with torch.nn.functional.linear in the GraphModule's code.
+#
+#     Note: This function is currently commented out as the replacement is now handled
+#     by simple string replacement in serialize_graph_module_to_str.
+#
+#     Args:
+#         gm: The GraphModule to modify.
+#
+#     Returns:
+#         Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
+#     """
+#     import ast
+#     import torch
+#     import types
+#
+#     # First recompile to generate code
+#     gm.recompile()
+#
+#     # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
+#     code_str = gm.code
+#
+#     # Parse AST
+#     tree = ast.parse(code_str)
+#
+#     class LinearReplacer(ast.NodeTransformer):
+#         def visit_Call(self, node):
+#             # Check if it's a torch._C._nn.linear call
+#             # Structure: torch._C._nn.linear(...)
+#             filtered_nodes = [
+#                 node
+#                 for node in [node]
+#                 if isinstance(node.func, ast.Attribute)
+#                 if node.func.attr == "linear"
+#                 if isinstance(node.func.value, ast.Attribute)
+#                 if node.func.value.attr == "_nn"
+#                 if isinstance(node.func.value.value, ast.Attribute)
+#                 if node.func.value.value.attr == "_C"
+#                 if isinstance(node.func.value.value.value, ast.Name)
+#                 if node.func.value.value.value.id == "torch"
+#             ]
+#             if filtered_nodes:
+#                 # Found torch._C._nn.linear, replace with torch.nn.functional.linear
+#                 new_func = ast.Attribute(
+#                     value=ast.Attribute(
+#                         value=ast.Attribute(
+#                             value=ast.Name(
+#                                 id="torch",
+#                                 ctx=ast.Load(),
+#                             ),
+#                             attr="nn",
+#                             ctx=ast.Load(),
+#                         ),
+#                         attr="functional",
+#                         ctx=ast.Load(),
+#                     ),
+#                     attr="linear",
+#                     ctx=ast.Load(),
+#                 )
+#                 node.func = new_func
+#             return self.generic_visit(node)
+#
+#     transformer = LinearReplacer()
+#     modified_tree = transformer.visit(tree)
+#     ast.fix_missing_locations(modified_tree)
+#
+#     # Convert the modified AST back to a code string
+#     new_code = ast.unparse(modified_tree)
+#
+#     # Recompile the modified code
+#     # Need to import device, inf and other names that may be used
+#     namespace = {
+#         "torch": torch,
+#     }
+#     # Try to import device (if used in the code)
+#     try:
+#         from torch import device
+#
+#         namespace["device"] = device
+#     except ImportError:
+#         pass
+#     # Try to import inf (if used in the code)
+#     try:
+#         from torch import inf
+#
+#         namespace["inf"] = inf
+#     except ImportError:
+#         # If torch doesn't have inf, use math.inf
+#         try:
+#             import math
+#
+#             namespace["inf"] = math.inf
+#         except:
+#             pass
+#
+#     exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
+#
+#     # Update the GraphModule's forward method
+#     forward_func = namespace.get("forward")
+#     if forward_func:
+#         gm.forward = types.MethodType(forward_func, gm)
+#
+#     # Use serialize_graph_module_to_str to get the serialized code
+#     # This ensures the code is properly serialized with unstable API replacements
+#     serialized_code = serialize_graph_module_to_str(gm)
+#     gm._code = serialized_code
+#
+#     return gm
+
+
 def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
     """
     Serialize a GraphModule to a string representation, replacing unstable APIs
@@ -27,14 +141,14 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         (r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
         # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
         # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
-        # replace this line with modification code for task 118 (torch._C._nn.softplus)
+        (r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
         (r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
         # replace this line with modification code for task 121 (torch._C._set_grad_enabled)
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
         # replace this line with modification code for task 123 (torch._C._nn.pad)
         # replace this line with modification code for task 125 (torch._C._nn.gelu)
         # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
-        # replace this line with modification code for task 127 (torch._C._nn.linear)
+        (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
     ]
     for pattern, repl in replacements:
         code = re.sub(pattern, repl, code)
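
To illustrate the string-level rewrite, the sketch below applies the two new patterns to a made-up line of FX-generated code (the input string is invented, not taken from the repository):

import re

code = "out = torch._C._nn.linear(x, w, b); act = torch._C._nn.softplus(out)"
replacements = [
    (r"torch\._C\._nn\.softplus\(", "torch.nn.functional.softplus("),
    (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
]
for pattern, repl in replacements:
    code = re.sub(pattern, repl, code)
print(code)
# out = torch.nn.functional.linear(x, w, b); act = torch.nn.functional.softplus(out)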
