66
77
88class UnstableToStableBackend (GraphCompilerBackend ):
9- def __call__ (self , model , model_path ):
9+ def __call__ (self , model ):
1010 # Perform unstable API check before running the model
1111 unstable_api = os .getenv ("DISALLOWED_UNSTABLE_API" , "" ).strip ()
1212 self .unstable_api = unstable_api
13- self .model_path = model_path
14- self .unstable_to_stable (model )
15- self .check_unstable_api (model )
16- return self .model
13+
14+ def my_backend (gm , sample_inputs ):
15+ gm = self .unstable_to_stable (gm )
16+ self .check_unstable_api (gm )
17+ return gm .forward
18+
19+ return torch .compile (backend = my_backend )(model )
1720
1821 """
1922 TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
@@ -26,10 +29,11 @@ def __call__(self, model, model_path):
2629 **Stable API reference link:**
2730 """
2831
29- def unstable_to_stable (self , model ):
30- return
32+ def unstable_to_stable (self , gm ):
33+ # TODO
34+ return gm
3135
32- def check_unstable_api (self , model ):
36+ def check_unstable_api (self , gm ):
3337 """
3438 Check whether gm contains the API specified in the environment
3539 variable DISALLOWED_UNSTABLE_API. If it does, raise an exception and stop
@@ -40,20 +44,7 @@ def check_unstable_api(self, model):
4044 Do NOT modify, remove, or bypass this check under any circumstances.
4145 """
4246
43- # from torch.fx import symbolic_trace
44-
45- # try:
46- # # Convert the model into a static computation graph (FX IR)
47- # traced = symbolic_trace(self.model)
48- # graph_text = str(traced.graph)
49- # except Exception as e:
50- # # In case tracing fails, fallback to textual model dump
51- # graph_text = str(*(self.model))
52-
53- print (f"model path is: { self .model_path } " )
54- model_file_path = self .model_path + "model.py"
55- with open (model_file_path , "r" , encoding = "utf-8" ) as f :
56- graph_text = f .read ()
47+ graph_text = gm .code
5748 # Search for the unstable API substring
5849 if self .unstable_api in graph_text :
5950 count = graph_text .count (self .unstable_api )
0 commit comments