@@ -30,6 +30,36 @@ def my_backend(gm, sample_inputs):
3030 **Stable API reference link:**
3131 """
3232
33+ def fft_irfft_to_irfft (self , gm ):
34+ def replace_in_graph (graph_mod ):
35+ # Register stable implementation on GraphModule, codegen can use self.irfft
36+ try :
37+ setattr (graph_mod , "irfft" , torch .fft .irfft )
38+ except Exception :
39+ pass
40+
41+ for node in graph_mod .graph .nodes :
42+ if node .op == "call_function" :
43+ # Match for all forms of target names
44+ if "fft_irfft" in str (node .target ):
45+ # Directly point target to Python layer function
46+ node .target = torch .fft .irfft
47+ # Validate and recompile the graph
48+ graph_mod .graph .lint ()
49+ graph_mod .recompile ()
50+
51+ # Process main gm and all nested GraphModules
52+ modules = [gm ]
53+ modules += [
54+ m
55+ for _ , m in gm .named_modules ()
56+ if isinstance (m , torch .fx .GraphModule ) and m is not gm
57+ ]
58+ for m in modules :
59+ replace_in_graph (m )
60+
61+ return gm
62+
3363 def avg_pool2d_to_avg_pool2d (self , gm ):
3464 """
3565 Convert torch._C._nn.avg_pool2d to torch.nn.functional.avg_pool2d
@@ -56,6 +86,8 @@ def unstable_to_stable(self, gm):
5686 # Convert based on unstable_api environment variable
5787 if self .unstable_api == "torch._C._nn.avg_pool2d" :
5888 gm = self .avg_pool2d_to_avg_pool2d (gm )
89+ elif self .unstable_api == "torch._C._fft.fft_irfft" :
90+ gm = self .fft_irfft_to_irfft (gm )
5991 return gm
6092
6193 def check_unstable_api (self , gm ):
0 commit comments