Skip to content

Commit 5c8140d

Browse files
committed
add unstable_to_stable backend and check_unstable_api
1 parent 377e935 commit 5c8140d

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import os
2+
import torch
3+
from .graph_compiler_backend import GraphCompilerBackend
4+
5+
6+
class UnstableToStableBackend(GraphCompilerBackend):
7+
def __call__(self, model):
8+
# Perform unstable API check before running the model
9+
self.model = model
10+
self.unstable_to_stable()
11+
self.check_unstable_api()
12+
return self.model
13+
14+
"""
15+
TODO: 实现将 self.model 中的不稳定(unstable)API 转换为稳定(stable)API 的逻辑。
16+
该 API 负责遍历 self.model,并将其中调用的实验性或不稳定接口替换为对应的稳定版本。
17+
注意:此逻辑属于模型编译安全机制的重要组成部分,请勿随意修改或删除。
18+
19+
api命名规范:
20+
<unstable>_to_<stable>
21+
22+
stable api链接:
23+
"""
24+
25+
def unstable_to_stable(self):
26+
return
27+
28+
def check_unstable_api(self):
29+
"""
30+
Check whether gm contains the API specified in the environment
31+
variable DISALLOWED_UNSTABLE_API. If it does, raise an exception and stop
32+
execution immediately.
33+
34+
IMPORTANT:
35+
This logic is part of the GraphNet compiler safety mechanism.
36+
Do NOT modify, remove, or bypass this check under any circumstances.
37+
"""
38+
unstable_api = os.getenv("DISALLOWED_UNSTABLE_API", "").strip()
39+
if not unstable_api:
40+
return # Skip check if no environment variable is set
41+
42+
from torch.fx import symbolic_trace
43+
44+
try:
45+
# Convert the model into a static computation graph (FX IR)
46+
traced = symbolic_trace(self.model)
47+
graph_text = str(traced.graph)
48+
except Exception as e:
49+
# In case tracing fails, fallback to textual model dump
50+
graph_text = str(self.model)
51+
52+
# Search for the unstable API substring
53+
if unstable_api in graph_text:
54+
count = graph_text.count(unstable_api)
55+
raise RuntimeError(
56+
f"❌ Detected unstable API '{unstable_api}' '{count}' times in model graph.\n"
57+
f"Please replace it with a stable API before proceeding.\n"
58+
)
59+
else:
60+
print(f"✅ Model passed: no occurrence of '{unstable_api}' found.")
61+
62+
def synchronize(self):
63+
# Synchronize CUDA operations if available
64+
if torch.cuda.is_available():
65+
torch.cuda.synchronize()

graph_net/torch/test_compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@
2121
from graph_net.torch.backend.tensorrt_backend import TensorRTBackend
2222
from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
2323
from graph_net.torch.backend.nope_backend import NopeBackend
24+
from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
2425
from graph_net.test_compiler_util import generate_allclose_configs
2526

27+
2628
registry_backend = {
2729
"tvm": TvmBackend(),
2830
"xla": XlaBackend(),
2931
"inductor": InductorBackend(),
3032
"tensorrt": TensorRTBackend(),
3133
"bladedisc": BladeDISCBackend(),
3234
"nope": NopeBackend(),
35+
"unstable_to_stable": UnstableToStableBackend(),
3336
}
3437

3538

@@ -215,7 +218,7 @@ def test_single_model(args):
215218
)
216219

217220
version_str = "unknown"
218-
if args.compiler == "inductor":
221+
if args.compiler in ["inductor", "unstable_to_stable"]:
219222
version_str = torch.__version__
220223
elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
221224
# Assuming compiler object has a version attribute

0 commit comments

Comments
 (0)