1111import numpy as np
1212import graph_net
1313import os
14- import re
14+ import ast
1515import paddle
1616
1717
@@ -29,29 +29,45 @@ def _get_sha_hash(content):
2929 return m .hexdigest ()
3030
3131
32- def _save_to_model_path ( dump_dir , hash_text ):
33- file_path = f" { dump_dir } /graph_hash.txt"
34- with open (file_path , "w " ) as f :
35- f . write ( hash_text )
32+ def _extract_forward_source ( model_path ):
33+ source = None
34+ with open (f" { model_path } /model.py" , "r " ) as f :
35+ source = f . read ( )
3636
37+ tree = ast .parse (source )
38+ forward_code = None
3739
38- def extract_from_forward_regex (text , case_sensitive = True ):
39- pattern = r"forward.*"
40- flags = 0 if case_sensitive else re .IGNORECASE
40+ for node in tree .body :
41+ if isinstance (node , ast .ClassDef ) and node .name == "GraphModule" :
42+ for fn in node .body :
43+ if isinstance (fn , ast .FunctionDef ) and fn .name == "forward" :
44+ return ast .unparse (fn )
45+ return None
4146
42- match = re .search (pattern , text , flags )
43- if match :
44- return match .group (0 )
47+
48+ def check_graph_hash (args ):
49+ model_path = args .model_path
50+ file_path = f"{ model_path } /graph_hash.txt"
51+ if args .dump_graph_hash_key :
52+ model_str = _extract_forward_source (model_path )
53+ assert model_str is not None , f"model_str of { args .model_path } is None."
54+ new_hash_text = _get_sha_hash (model_str )
55+ if os .path .exists (file_path ):
56+ with open (file_path , "r" ) as f :
57+ old_hash_text = f .read ()
58+ assert (
59+ new_hash_text == old_hash_text
60+ ), f"Hash value for { model_path } is not consistent."
61+ else :
62+ print (f"Writing to { file_path } ." )
63+ with open (file_path , "w" ) as f :
64+ f .write (new_hash_text )
4565 else :
46- raise ValueError ( "Erroneous case occurs." )
66+ assert os . path . exists ( file_path ), f" { file_path } does not exist."
4767
4868
4969def main (args ):
50- model_path = args .model_path
51- with open (f"{ model_path } /model.py" , "r" ) as fp :
52- model_str = fp .read ()
53- model_str = extract_from_forward_regex (model_str )
54- _save_to_model_path (model_path , _get_sha_hash (model_str ))
70+ check_graph_hash (args )
5571
5672 model_path = args .model_path
5773 model_class = load_class_from_file (
@@ -100,17 +116,16 @@ def main(args):
100116 required = True ,
101117 help = "Path to folder e.g '../test_dataset'" ,
102118 )
103-
104119 parser .add_argument (
105120 "--no-check-redundancy" ,
106121 action = "store_true" ,
122+ default = False ,
107123 help = "whether check model graph redundancy" ,
108124 )
109-
110125 parser .add_argument (
111126 "--dump-graph-hash-key" ,
112127 action = "store_true" ,
113- default = False ,
128+ default = True ,
114129 help = "Dump graph hash key" ,
115130 )
116131 parser .add_argument (
0 commit comments