Skip to content

Commit 02ebe4a

Browse files
committed
resolve merge conflicts
1 parent 70996ea commit 02ebe4a

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,50 +126,70 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126

127127
return gm
128128

129-
def _impl_unstable_to_stable_one_hot(self, gm):
129+
def _impl_unstable_to_stable_special_logit(self, gm):
130130
"""
131-
Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
131+
Convert torch._C._special.special_logit to torch.special.logit
132132
"""
133-
import torch.nn.functional as F
134-
135133
issue_nodes = (
136134
node
137135
for node in gm.graph.nodes
138136
if node.op == "call_function"
139137
if hasattr(node.target, "__module__")
140-
if node.target.__module__ == "torch._C._nn"
138+
if node.target.__module__ == "torch._C._special"
141139
if hasattr(node.target, "__name__")
142-
if node.target.__name__ == "one_hot"
140+
if node.target.__name__ == "special_logit"
143141
)
144142
for node in issue_nodes:
145-
node.target = F.one_hot
143+
node.target = torch.special.logit
146144

147145
# Recompile the graph
148146
gm.recompile()
149147

150148
return gm
151149

152-
def _impl_unstable_to_stable_special_logit(self, gm):
150+
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
151+
152+
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
153+
154+
# replace this line with modification code for task 118 (torch._C._nn.softplus)
155+
156+
# replace this line with modification code for task 119 (torch._C._nn.one_hot)
157+
158+
def _impl_unstable_to_stable_one_hot(self, gm):
153159
"""
154-
Convert torch._C._special.special_logit to torch.special.logit
160+
Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
155161
"""
162+
import torch.nn.functional as F
163+
156164
issue_nodes = (
157165
node
158166
for node in gm.graph.nodes
159167
if node.op == "call_function"
160168
if hasattr(node.target, "__module__")
161-
if node.target.__module__ == "torch._C._special"
169+
if node.target.__module__ == "torch._C._nn"
162170
if hasattr(node.target, "__name__")
163-
if node.target.__name__ == "special_logit"
171+
if node.target.__name__ == "one_hot"
164172
)
165173
for node in issue_nodes:
166-
node.target = torch.special.logit
174+
node.target = F.one_hot
167175

168176
# Recompile the graph
169177
gm.recompile()
170178

171179
return gm
172180

181+
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
182+
183+
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
184+
185+
# replace this line with modification code for task 123 (torch._C._nn.pad)
186+
187+
# replace this line with modification code for task 125 (torch._C._nn.gelu)
188+
189+
# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
190+
191+
# replace this line with modification code for task 127 (torch._C._nn.linear)
192+
173193
def unstable_to_stable(self, gm):
174194
methods = (
175195
name

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,17 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
2424
(r"torch\._C\._fft\.fft_irfft\(", "torch.fft.irfft("),
2525
(r"torch\._C\._fft\.fft_rfft\(", "torch.fft.rfft("),
2626
(r"torch\._C\._fft\.fft_fftn\(", "torch.fft.fftn("),
27-
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
2827
(r"torch\._C\._special\.special_logit\(", "torch.special.logit("),
29-
# Add new rules to this list as needed
28+
# replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
29+
# replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
30+
# replace this line with modification code for task 118 (torch._C._nn.softplus)
31+
(r"torch\._C\._nn\.one_hot\(", "torch.nn.functional.one_hot("),
32+
# replace this line with modification code for task 121 (torch._C._set_grad_enabled)
33+
# replace this line with modification code for task 122 (torch._C._log_api_usage_once)
34+
# replace this line with modification code for task 123 (torch._C._nn.pad)
35+
# replace this line with modification code for task 125 (torch._C._nn.gelu)
36+
# replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
37+
# replace this line with modification code for task 127 (torch._C._nn.linear)
3038
]
3139
for pattern, repl in replacements:
3240
code = re.sub(pattern, repl, code)

0 commit comments

Comments
 (0)