33import sys
44import inspect
55from .graph_compiler_backend import GraphCompilerBackend
6+ from ..fx_graph_serialize_util import serialize_graph_module_to_str
67
78
89class UnstableToStableBackend (GraphCompilerBackend ):
@@ -12,7 +13,7 @@ def __call__(self, model):
1213 self .unstable_api = unstable_api
1314
1415 def my_backend (gm , sample_inputs ):
15- gm = self .fft_irfft_to_irfft (gm )
16+ gm = self .unstable_to_stable (gm )
1617 self .check_unstable_api (gm )
1718 return gm .forward
1819
@@ -59,6 +60,36 @@ def replace_in_graph(graph_mod):
5960
6061 return gm
6162
63+ def avg_pool2d_to_avg_pool2d (self , gm ):
64+ """
65+ Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
66+ """
67+ import torch .nn .functional as F
68+
69+ # Update graph nodes: replace torch._C._nn.avg_pool2d with F.avg_pool2d
70+ for node in gm .graph .nodes :
71+ if node .op == "call_function" :
72+ if (
73+ hasattr (node .target , "__module__" )
74+ and hasattr (node .target , "__name__" )
75+ and node .target .__module__ == "torch._C._nn"
76+ and node .target .__name__ == "avg_pool2d"
77+ ):
78+ node .target = F .avg_pool2d
79+
80+ # Recompile the graph
81+ gm .recompile ()
82+
83+ return gm
84+
85+ def unstable_to_stable (self , gm ):
86+ # Convert based on unstable_api environment variable
87+ if self .unstable_api == "torch._C._nn.avg_pool2d" :
88+ gm = self .avg_pool2d_to_avg_pool2d (gm )
89+ elif self .unstable_api == "torch._C._fft.fft_irfft" :
90+ gm = self .fft_irfft_to_irfft (gm )
91+ return gm
92+
6293 def check_unstable_api (self , gm ):
6394 """
6495 Check whether gm contains the API specified in the environment
@@ -70,7 +101,8 @@ def check_unstable_api(self, gm):
70101 Do NOT modify, remove, or bypass this check under any circumstances.
71102 """
72103
73- graph_text = gm .code
104+ # Use serialized code to check for unstable APIs
105+ graph_text = serialize_graph_module_to_str (gm )
74106 # Search for the unstable API substring
75107 if self .unstable_api in graph_text :
76108 count = graph_text .count (self .unstable_api )
0 commit comments