File tree Expand file tree Collapse file tree 2 files changed +24
-4
lines changed
Expand file tree Collapse file tree 2 files changed +24
-4
lines changed Original file line number Diff line number Diff line change 11import os
22import torch
33import sys
4- import inspect
5- import ast
64from .graph_compiler_backend import GraphCompilerBackend
75from ..fx_graph_serialize_util import serialize_graph_module_to_str
86
@@ -318,7 +316,26 @@ def replace_in_graph(graph_mod):
318316
319317 return gm
320318
321- # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
319+ def _impl_unstable_to_stable_sdpa (self , gm ):
320+ """
321+ Convert torch._C._nn.scaled_dot_product_attention to torch.nn.functional.scaled_dot_product_attention
322+ """
323+ issue_nodes = (
324+ node
325+ for node in gm .graph .nodes
326+ if node .op == "call_function"
327+ if hasattr (node .target , "__module__" )
328+ if node .target .__module__ == "torch._C._nn"
329+ if hasattr (node .target , "__name__" )
330+ if node .target .__name__ == "scaled_dot_product_attention"
331+ )
332+
333+ for node in issue_nodes :
334+ node .target = torch .nn .functional .scaled_dot_product_attention
335+
336+ gm .recompile ()
337+
338+ return gm
322339
323340 def _impl_unstable_to_stable_linear_to_functional_linear (self , gm ):
324341 """
Original file line number Diff line number Diff line change @@ -148,7 +148,10 @@ def serialize_graph_module_to_str(gm: torch.fx.GraphModule) -> str:
148148 # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
149149 (r"torch\._C\._nn\.pad\(" , "torch.nn.functional.pad(" ),
150150 (r"torch\._C\._nn\.gelu\(" , "torch.nn.functional.gelu(" ),
151- # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
151+ (
152+ r"torch\._C\._nn\.scaled_dot_product_attention\(" ,
153+ "torch.nn.functional.scaled_dot_product_attention(" ,
154+ ),
152155 (r"torch\._C\._nn\.linear\(" , "torch.nn.functional.linear(" ),
153156 ]
154157 for pattern , repl in replacements :
You can’t perform that action at this time.
0 commit comments