Skip to content

Commit c9a2358

Browse files
authored
unstable to stable challenge (#313)
1 parent a1945ba commit c9a2358

File tree

3 files changed

+21
-26
lines changed

3 files changed

+21
-26
lines changed

graph_net/torch/backend/unstable_to_stable_backend.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66

77

88
class UnstableToStableBackend(GraphCompilerBackend):
9-
def __call__(self, model, model_path):
9+
def __call__(self, model):
1010
# Perform unstable API check before running the model
1111
unstable_api = os.getenv("DISALLOWED_UNSTABLE_API", "").strip()
1212
self.unstable_api = unstable_api
13-
self.model_path = model_path
14-
self.unstable_to_stable(model)
15-
self.check_unstable_api(model)
16-
return self.model
13+
14+
def my_backend(gm, sample_inputs):
15+
gm = self.unstable_to_stable(gm)
16+
self.check_unstable_api(gm)
17+
return gm.forward
18+
19+
return torch.compile(backend=my_backend)(model)
1720

1821
"""
1922
TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
@@ -26,10 +29,11 @@ def __call__(self, model, model_path):
2629
**Stable API reference link:**
2730
"""
2831

29-
def unstable_to_stable(self, model):
30-
return
32+
def unstable_to_stable(self, gm):
33+
# TODO
34+
return gm
3135

32-
def check_unstable_api(self, model):
36+
def check_unstable_api(self, gm):
3337
"""
3438
Check whether gm contains the API specified in the environment
3539
variable DISALLOWED_UNSTABLE_API. If it does, raise an exception and stop
@@ -40,20 +44,7 @@ def check_unstable_api(self, model):
4044
Do NOT modify, remove, or bypass this check under any circumstances.
4145
"""
4246

43-
# from torch.fx import symbolic_trace
44-
45-
# try:
46-
# # Convert the model into a static computation graph (FX IR)
47-
# traced = symbolic_trace(self.model)
48-
# graph_text = str(traced.graph)
49-
# except Exception as e:
50-
# # In case tracing fails, fallback to textual model dump
51-
# graph_text = str(*(self.model))
52-
53-
print(f"model path is: {self.model_path}")
54-
model_file_path = self.model_path + "model.py"
55-
with open(model_file_path, "r", encoding="utf-8") as f:
56-
graph_text = f.read()
47+
graph_text = gm.code
5748
# Search for the unstable API substring
5849
if self.unstable_api in graph_text:
5950
count = graph_text.count(self.unstable_api)

graph_net/torch/test_compiler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ def test_single_model(args):
289289
if args.compiler == "xla":
290290
xla_model = get_model(args, "xla")
291291
compiled_model = compiler(xla_model)
292-
elif args.compiler == "unstable_to_stable":
293-
compiled_model = compiler(model, args.model_path)
294292
else:
295293
compiled_model = compiler(model)
296294

plot_unstable_to_stable.sh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@ if [ -z "$DISALLOWED_UNSTABLE_API" ]; then
77
exit 1
88
fi
99

10+
if [ -z "$GRAPH_NET_WORKSPACE" ]; then
11+
echo "❌ 环境变量 GRAPH_NET_WORKSPACE 未设置!"
12+
echo "请使用: export GRAPH_NET_WORKSPACE=/path/to/GraphNet"
13+
exit 1
14+
fi
15+
1016
# === 配置区 ===
11-
root_dir="/root/GraphNet/todo_works/unstable_api_to_stable_api/${DISALLOWED_UNSTABLE_API}"
17+
root_dir="${GRAPH_NET_WORKSPACE}/todo_works/unstable_api_to_stable_api/${DISALLOWED_UNSTABLE_API}"
1218
file_list="${root_dir}/${DISALLOWED_UNSTABLE_API}_files.txt"
1319
log_file="${root_dir}/log.log"
1420
json_output_dir="${root_dir}/JSON_results"
@@ -68,7 +74,7 @@ else
6874
fi
6975

7076
echo "📦 正在将JSON转换为结果图"
71-
python -m graph_net.S_analysis \
77+
python -m graph_net.plot_ESt \
7278
--benchmark-path $GRAPH_NET_BENCHMARK_PATH/JSON_results/ \
7379
--output-dir $GRAPH_NET_BENCHMARK_PATH \
7480

0 commit comments

Comments
 (0)