Skip to content

Commit 6910563

Browse files
committed
Use ast to generate the hash of Paddle models.
1 parent 2b62d8f commit 6910563

File tree

1 file changed

+35
-20
lines changed

1 file changed

+35
-20
lines changed

graph_net/paddle/validate.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import graph_net
1313
import os
14-
import re
14+
import ast
1515
import paddle
1616

1717

@@ -29,29 +29,45 @@ def _get_sha_hash(content):
2929
return m.hexdigest()
3030

3131

32-
def _save_to_model_path(dump_dir, hash_text):
33-
file_path = f"{dump_dir}/graph_hash.txt"
34-
with open(file_path, "w") as f:
35-
f.write(hash_text)
32+
def _extract_forward_source(model_path):
33+
source = None
34+
with open(f"{model_path}/model.py", "r") as f:
35+
source = f.read()
3636

37+
tree = ast.parse(source)
38+
forward_code = None
3739

38-
def extract_from_forward_regex(text, case_sensitive=True):
39-
pattern = r"forward.*"
40-
flags = 0 if case_sensitive else re.IGNORECASE
40+
for node in tree.body:
41+
if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
42+
for fn in node.body:
43+
if isinstance(fn, ast.FunctionDef) and fn.name == "forward":
44+
return ast.unparse(fn)
45+
return None
4146

42-
match = re.search(pattern, text, flags)
43-
if match:
44-
return match.group(0)
47+
48+
def check_graph_hash(args):
    """Verify, or create, the ``graph_hash.txt`` file stored with a model.

    When ``args.dump_graph_hash_key`` is truthy, the hash of the model's
    ``forward`` source is computed. If ``graph_hash.txt`` already exists the
    new hash must equal the stored one; otherwise the new hash is written to
    that file. When the flag is falsy, only the existence of
    ``graph_hash.txt`` is checked.

    Args:
        args: Object exposing ``model_path`` (model directory) and
            ``dump_graph_hash_key`` (bool-like) attributes.

    Raises:
        AssertionError: If the forward source cannot be extracted, the stored
            hash differs from the newly computed one, or (when not dumping)
            ``graph_hash.txt`` does not exist.
    """
    model_path = args.model_path
    file_path = f"{model_path}/graph_hash.txt"

    if not args.dump_graph_hash_key:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.
        if not os.path.exists(file_path):
            raise AssertionError(f"{file_path} does not exist.")
        return

    model_str = _extract_forward_source(model_path)
    if model_str is None:
        raise AssertionError(f"model_str of {args.model_path} is None.")
    new_hash_text = _get_sha_hash(model_str)

    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            # Tolerate a trailing newline or stray whitespace in the stored file.
            old_hash_text = f.read().strip()
        if new_hash_text != old_hash_text:
            raise AssertionError(f"Hash value for {model_path} is not consistent.")
    else:
        print(f"Writing to {file_path}.")
        with open(file_path, "w") as f:
            f.write(new_hash_text)
4767

4868

4969
def main(args):
50-
model_path = args.model_path
51-
with open(f"{model_path}/model.py", "r") as fp:
52-
model_str = fp.read()
53-
model_str = extract_from_forward_regex(model_str)
54-
_save_to_model_path(model_path, _get_sha_hash(model_str))
70+
check_graph_hash(args)
5571

5672
model_path = args.model_path
5773
model_class = load_class_from_file(
@@ -100,17 +116,16 @@ def main(args):
100116
required=True,
101117
help="Path to folder e.g '../test_dataset'",
102118
)
103-
104119
parser.add_argument(
105120
"--no-check-redundancy",
106121
action="store_true",
122+
default=False,
107123
help="whether check model graph redundancy",
108124
)
109-
110125
parser.add_argument(
111126
"--dump-graph-hash-key",
112127
action="store_true",
113-
default=False,
128+
default=True,
114129
help="Dump graph hash key",
115130
)
116131
parser.add_argument(

0 commit comments

Comments
 (0)