Skip to content

Commit 7022de4

Browse files
committed
add graph_net/torch/unittest/test_ast_renamer.py
1 parent d9b01dd commit 7022de4

File tree

5 files changed

+119
-2
lines changed

5 files changed

+119
-2
lines changed

.github/workflows/Validate-GPU.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ jobs:
8585
bash ${work_dir}/tools/ci/check_validate.sh
8686
'
8787
88+
- name: Run AST Renamer Unit Test
89+
if: steps.check-bypass.outputs.can-skip != 'true'
90+
run: |
91+
docker exec -t ${{ env.container_name }} /bin/bash -c '
92+
export PYTHONPATH=$PYTHONPATH:${{ github.workspace }}
93+
python3 -m unittest graph_net/torch/unittest/test_ast_renamer.py
94+
'
95+
8896
- name: Terminate and delete the container
8997
if: always()
9098
run: |

graph_net/sample_pass/ast_graph_variable_renamer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ def _clean_delete(self, stmt):
232232

233233
def _filter_delete_target(self, target):
234234
if isinstance(target, ast.Tuple): # del (a, b)
235-
kept_elts = [e for e in target.elts if not self._is_protected_var(e)]
235+
kept_elts = [e for e in target.elts if not self._is_input_or_weight_var(e)]
236236
return ast.Tuple(elts=kept_elts, ctx=ast.Del()) if kept_elts else None
237-
elif not self._is_protected_var(target): # del a
237+
elif not self._is_input_or_weight_var(target):
238238
return target
239239
else:
240240
pass

graph_net/tools/generate_subgraph_dataset.sh

100644100755
File mode changed.

graph_net/tools/get_in_tensor_symbolic_shapes.sh

100644100755
File mode changed.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import unittest
2+
import shutil
3+
import tempfile
4+
import textwrap
5+
from pathlib import Path
6+
7+
from graph_net.sample_pass.ast_graph_variable_renamer import AstGraphVariableRenamer
8+
from graph_net.tensor_meta import TensorMeta
9+
10+
11+
class TestAstGraphVariableRenamerProduction(unittest.TestCase):
12+
def setUp(self):
13+
self.test_root = Path(tempfile.mkdtemp())
14+
self.src_repo = self.test_root / "src_repo"
15+
self.dst_workspace = self.test_root / "workspace"
16+
self.src_repo.mkdir()
17+
self.dst_workspace.mkdir()
18+
19+
self.rel_model_path = "samples/demo_model"
20+
self.full_src_path = self.src_repo / self.rel_model_path
21+
self.full_src_path.mkdir(parents=True)
22+
23+
(self.full_src_path / "model.py").write_text(
24+
textwrap.dedent(
25+
"""
26+
import torch
27+
class GraphModule(torch.nn.Module):
28+
def forward(self, x_data, L__self___weight):
29+
res = x_data + L__self___weight
30+
L__self___weight = None
31+
return res
32+
"""
33+
)
34+
)
35+
36+
def save_meta_file(filename, var_name):
37+
meta = TensorMeta(
38+
name=var_name,
39+
shape=[1, 10],
40+
dtype="torch.float32",
41+
record_class_name="TensorMeta",
42+
original_name=None,
43+
device="cpu",
44+
mean=0.0,
45+
std=1.0,
46+
data=None,
47+
max_val=1.0,
48+
min_val=-1.0,
49+
)
50+
(self.full_src_path / filename).write_text(meta.serialize_to_py_str())
51+
52+
save_meta_file("input_meta.py", "x_data")
53+
save_meta_file("weight_meta.py", "L__self___weight")
54+
55+
import graph_net.torch.constraint_util as cu
56+
57+
self.real_constraint_path = cu.__file__
58+
59+
def tearDown(self):
60+
shutil.rmtree(self.test_root)
61+
62+
def test_end_to_end_renaming_logic(self):
63+
handler_config = {
64+
"device": "cpu",
65+
"resume": True,
66+
"try_run": False,
67+
"model_path_prefix": str(self.src_repo),
68+
"output_dir": str(self.dst_workspace),
69+
"data_input_predicator_filepath": self.real_constraint_path,
70+
"data_input_predicator_class_name": "NaiveDataInputPredicator",
71+
"data_input_predicator_config": {},
72+
"model_runnable_predicator_filepath": self.real_constraint_path,
73+
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
74+
"model_runnable_predicator_config": {},
75+
}
76+
77+
renamer = AstGraphVariableRenamer(handler_config)
78+
renamer(self.rel_model_path)
79+
80+
target_dir = self.dst_workspace / self.rel_model_path
81+
new_code = (target_dir / "model.py").read_text()
82+
self.assertIn("in_0", new_code, "x_data 应该被识别为 in_0")
83+
self.assertIn("w_0", new_code, "L__self___weight 应该被识别为 w_0")
84+
self.assertIn("tmp_0", new_code, "中间变量 res 应该被重命名为 tmp_0")
85+
self.assertNotIn("None", new_code, "权重清理语句应被 AST 转换器删除")
86+
87+
new_weight_metas = TensorMeta.unserialize_from_py_file(
88+
str(target_dir / "weight_meta.py")
89+
)
90+
self.assertEqual(new_weight_metas[0].name, "w_0")
91+
self.assertEqual(new_weight_metas[0].original_name, "L__self___weight")
92+
93+
self.assertTrue((target_dir / "graph_hash.txt").exists())
94+
hash_val = (target_dir / "graph_hash.txt").read_text()
95+
self.assertEqual(len(hash_val), 64, "Hash 应为标准的 SHA256 长度")
96+
97+
def test_predicator_classification_diagnostic(self):
98+
from graph_net.imp_util import load_module
99+
100+
module = load_module(self.real_constraint_path)
101+
pred_cls = getattr(module, "NaiveDataInputPredicator")
102+
predicator = pred_cls({})
103+
104+
self.assertFalse(predicator(None, "L__self___weight"))
105+
self.assertTrue(predicator(None, "random_var_name"))
106+
107+
108+
if __name__ == "__main__":
109+
unittest.main()

0 commit comments

Comments
 (0)