@@ -12,7 +12,7 @@ def __call__(self, model):
1212 self .unstable_api = unstable_api
1313
1414 def my_backend (gm , sample_inputs ):
15- gm = self .unstable_to_stable (gm )
15+ gm = self .fft_irfft_to_irfft (gm )
1616 self .check_unstable_api (gm )
1717 return gm .forward
1818
@@ -29,8 +29,34 @@ def my_backend(gm, sample_inputs):
2929 **Stable API reference link:**
3030 """
3131
32- def unstable_to_stable (self , gm ):
33- # TODO
32+ def fft_irfft_to_irfft (self , gm ):
33+ def replace_in_graph (graph_mod ):
34+ # Register stable implementation on GraphModule, codegen can use self.irfft
35+ try :
36+ setattr (graph_mod , "irfft" , torch .fft .irfft )
37+ except Exception :
38+ pass
39+
40+ for node in graph_mod .graph .nodes :
41+ if node .op == "call_function" :
42+ # Match for all forms of target names
43+ if "fft_irfft" in str (node .target ):
44+ # Directly point target to Python layer function
45+ node .target = torch .fft .irfft
46+ # Validate and recompile the graph
47+ graph_mod .graph .lint ()
48+ graph_mod .recompile ()
49+
50+ # Process main gm and all nested GraphModules
51+ modules = [gm ]
52+ modules += [
53+ m
54+ for _ , m in gm .named_modules ()
55+ if isinstance (m , torch .fx .GraphModule ) and m is not gm
56+ ]
57+ for m in modules :
58+ replace_in_graph (m )
59+
3460 return gm
3561
3662 def check_unstable_api (self , gm ):
0 commit comments