Commit 5b16d63

init (#421)
1 parent bea41c2 commit 5b16d63

2 changed files: 24 additions, 4 deletions


graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 20 additions & 3 deletions
@@ -1,8 +1,6 @@
 import os
 import torch
 import sys
-import inspect
-import ast
 from .graph_compiler_backend import GraphCompilerBackend
 from ..fx_graph_serialize_util import serialize_graph_module_to_str
 

@@ -318,7 +316,26 @@ def replace_in_graph(graph_mod):
 
         return gm
 
-    # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
+    def _impl_unstable_to_stable_sdpa(self, gm):
+        """
+        Convert torch._C._nn.scaled_dot_product_attention to torch.nn.functional.scaled_dot_product_attention
+        """
+        issue_nodes = (
+            node
+            for node in gm.graph.nodes
+            if node.op == "call_function"
+            if hasattr(node.target, "__module__")
+            if node.target.__module__ == "torch._C._nn"
+            if hasattr(node.target, "__name__")
+            if node.target.__name__ == "scaled_dot_product_attention"
+        )
+
+        for node in issue_nodes:
+            node.target = torch.nn.functional.scaled_dot_product_attention
+
+        gm.recompile()
+
+        return gm
 
     def _impl_unstable_to_stable_linear_to_functional_linear(self, gm):
         """

graph_net/torch/fx_graph_serialize_util.py

Lines changed: 4 additions & 1 deletion
@@ -148,7 +148,10 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
         # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
         (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
         (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu("),
-        # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
+        (
+            r"torch\._C\._nn\.scaled_dot_product_attention\(",
+            "torch.nn.functional.scaled_dot_product_attention(",
+        ),
         (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
     ]
     for pattern, repl in replacements:
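The serializer change operates on text rather than the graph: serialize_graph_module_to_str post-processes the FX-generated source through a list of (pattern, repl) regex pairs, rewriting unstable torch._C._nn call sites to their stable torch.nn.functional spellings. A minimal sketch of that loop, assuming a hypothetical src line of generated code:

import re

# (pattern, repl) pairs mirroring the list in serialize_graph_module_to_str.
replacements = [
    (r"torch\._C\._nn\.pad\(", "torch.nn.functional.pad("),
    (r"torch\._C\._nn\.gelu\(", "torch.nn.functional.gelu("),
    (
        r"torch\._C\._nn\.scaled_dot_product_attention\(",
        "torch.nn.functional.scaled_dot_product_attention(",
    ),
    (r"torch\._C\._nn\.linear\(", "torch.nn.functional.linear("),
]

# Illustrative line of FX-generated source (not from the repository).
src = "attn = torch._C._nn.scaled_dot_product_attention(query, key, value)"
for pattern, repl in replacements:
    src = re.sub(pattern, repl, src)

print(src)
# attn = torch.nn.functional.scaled_dot_product_attention(query, key, value)

Escaping the trailing "(" in each pattern keeps the match anchored to an actual call site rather than a bare attribute reference.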
