22import torch .fx
33
44
5+ # def apply_ast_based_linear_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
6+ # """
7+ # Apply AST-based replacement of torch._C._nn.linear to torch.nn.functional.linear.
8+ #
9+ # This function uses AST parsing and transformation to replace torch._C._nn.linear
10+ # calls with torch.nn.functional.linear in the GraphModule's code.
11+ #
12+ # Note: This function is currently commented out as the replacement is now handled
13+ # by simple string replacement in serialize_graph_module_to_str.
14+ #
15+ # Args:
16+ # gm: The GraphModule to modify.
17+ #
18+ # Returns:
19+ # Modified GraphModule with torch._C._nn.linear replaced by torch.nn.functional.linear.
20+ # """
21+ # import ast
22+ # import torch
23+ # import types
24+ #
25+ # # First recompile to generate code
26+ # gm.recompile()
27+ #
28+ # # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
29+ # code_str = gm.code
30+ #
31+ # # Parse AST
32+ # tree = ast.parse(code_str)
33+ #
34+ # class LinearReplacer(ast.NodeTransformer):
35+ # def visit_Call(self, node):
36+ # # Check if it's a torch._C._nn.linear call
37+ # # Structure: torch._C._nn.linear(...)
38+ # filtered_nodes = [
39+ # node
40+ # for node in [node]
41+ # if isinstance(node.func, ast.Attribute)
42+ # if node.func.attr == "linear"
43+ # if isinstance(node.func.value, ast.Attribute)
44+ # if node.func.value.attr == "_nn"
45+ # if isinstance(node.func.value.value, ast.Attribute)
46+ # if node.func.value.value.attr == "_C"
47+ # if isinstance(node.func.value.value.value, ast.Name)
48+ # if node.func.value.value.value.id == "torch"
49+ # ]
50+ # if filtered_nodes:
51+ # # Found torch._C._nn.linear, replace with torch.nn.functional.linear
52+ # new_func = ast.Attribute(
53+ # value=ast.Attribute(
54+ # value=ast.Attribute(
55+ # value=ast.Name(
56+ # id="torch",
57+ # ctx=ast.Load(),
58+ # ),
59+ # attr="nn",
60+ # ctx=ast.Load(),
61+ # ),
62+ # attr="functional",
63+ # ctx=ast.Load(),
64+ # ),
65+ # attr="linear",
66+ # ctx=ast.Load(),
67+ # )
68+ # node.func = new_func
69+ # return self.generic_visit(node)
70+ #
71+ # transformer = LinearReplacer()
72+ # modified_tree = transformer.visit(tree)
73+ # ast.fix_missing_locations(modified_tree)
74+ #
75+ # # Convert the modified AST back to code string
76+ # new_code = ast.unparse(modified_tree)
77+ #
78+ # # Recompile the modified code
79+ # # Need to import device, inf and other modules that may be used
80+ # namespace = {
81+ # "torch": torch,
82+ # }
83+ # # Try to import device (if used in code)
84+ # try:
85+ # from torch import device
86+ #
87+ # namespace["device"] = device
88+ # except ImportError:
89+ # pass
90+ # # Try to import inf (if used in code)
91+ # try:
92+ # from torch import inf
93+ #
94+ # namespace["inf"] = inf
95+ # except ImportError:
96+ # # If torch doesn't have inf, use math.inf
97+ # try:
98+ # import math
99+ #
100+ # namespace["inf"] = math.inf
101+ # except:
102+ # pass
103+ #
104+ # exec(compile(modified_tree, filename="<ast>", mode="exec"), namespace)
105+ #
106+ # # Update GraphModule's forward method
107+ # forward_func = namespace.get("forward")
108+ # if forward_func:
109+ # gm.forward = types.MethodType(forward_func, gm)
110+ #
111+ # # Use serialize_graph_module_to_str to get the serialized code
112+ # # This ensures the code is properly serialized with unstable API replacements
113+ # serialized_code = serialize_graph_module_to_str(gm)
114+ # gm._code = serialized_code
115+ #
116+ # return gm
117+
118+
5119def serialize_graph_module_to_str (gm : torch .fx .GraphModule ) -> str :
6120 """
7121 Serialize a GraphModule to a string representation, replacing unstable APIs
@@ -25,7 +139,6 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
25139 (r"torch\._C\._fft\.fft_rfft\(" , "torch.fft.rfft(" ),
26140 (r"torch\._C\._fft\.fft_fftn\(" , "torch.fft.fftn(" ),
27141 (r"torch\._C\._special\.special_logit\(" , "torch.special.logit(" ),
28- (r"torch\._C\._nn\.linear\(" , "torch.nn.functional.linear(" ),
29142 # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
30143 # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
31144 # replace this line with modification code for task 118 (torch._C._nn.softplus)
@@ -35,6 +148,7 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
35148 # replace this line with modification code for task 123 (torch._C._nn.pad)
36149 # replace this line with modification code for task 125 (torch._C._nn.gelu)
37150 # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
151+ (r"torch\._C\._nn\.linear\(" , "torch.nn.functional.linear(" ),
38152 ]
39153 for pattern , repl in replacements :
40154 code = re .sub (pattern , repl , code )
0 commit comments