Skip to content

Commit 06227d7

Browse files
committed
Use ast to generate the hash of Paddle models.
1 parent 2b62d8f commit 06227d7

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

graph_net/paddle/validate.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import graph_net
1313
import os
14-
import re
14+
import ast
1515
import paddle
1616

1717

@@ -31,26 +31,32 @@ def _get_sha_hash(content):
3131

3232
def _save_to_model_path(dump_dir, hash_text):
3333
file_path = f"{dump_dir}/graph_hash.txt"
34+
print(f"Writing to {file_path}.")
3435
with open(file_path, "w") as f:
3536
f.write(hash_text)
3637

3738

38-
def extract_from_forward_regex(text, case_sensitive=True):
39-
pattern = r"forward.*"
40-
flags = 0 if case_sensitive else re.IGNORECASE
39+
def extract_forward_source(model_path):
40+
source = None
41+
with open(f"{model_path}/model.py", "r") as f:
42+
source = f.read()
4143

42-
match = re.search(pattern, text, flags)
43-
if match:
44-
return match.group(0)
45-
else:
46-
raise ValueError("Erroneous case occurs.")
44+
tree = ast.parse(source)
45+
forward_code = None
46+
47+
for node in tree.body:
48+
if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
49+
for fn in node.body:
50+
if isinstance(fn, ast.FunctionDef) and fn.name == "forward":
51+
return ast.unparse(fn)
52+
return None
4753

4854

4955
def main(args):
5056
model_path = args.model_path
51-
with open(f"{model_path}/model.py", "r") as fp:
52-
model_str = fp.read()
53-
model_str = extract_from_forward_regex(model_str)
57+
if args.dump_graph_hash_key:
58+
model_str = extract_forward_source(model_path)
59+
assert model_str is not None, f"model_str of {args.model_path} is None."
5460
_save_to_model_path(model_path, _get_sha_hash(model_str))
5561

5662
model_path = args.model_path
@@ -100,17 +106,16 @@ def main(args):
100106
required=True,
101107
help="Path to folder e.g '../test_dataset'",
102108
)
103-
104109
parser.add_argument(
105110
"--no-check-redundancy",
106111
action="store_true",
112+
default=False,
107113
help="whether check model graph redundancy",
108114
)
109-
110115
parser.add_argument(
111116
"--dump-graph-hash-key",
112117
action="store_true",
113-
default=False,
118+
default=True,
114119
help="Dump graph hash key",
115120
)
116121
parser.add_argument(

0 commit comments

Comments
 (0)