1111import numpy as np
1212import graph_net
1313import os
14- import re
14+ import ast
1515import paddle
1616
1717
@@ -31,26 +31,32 @@ def _get_sha_hash(content):
3131
3232def _save_to_model_path (dump_dir , hash_text ):
3333 file_path = f"{ dump_dir } /graph_hash.txt"
34+ print (f"Writing to { file_path } ." )
3435 with open (file_path , "w" ) as f :
3536 f .write (hash_text )
3637
3738
38- def extract_from_forward_regex (text , case_sensitive = True ):
39- pattern = r"forward.*"
40- flags = 0 if case_sensitive else re .IGNORECASE
39+ def extract_forward_source (model_path ):
40+ source = None
41+ with open (f"{ model_path } /model.py" , "r" ) as f :
42+ source = f .read ()
4143
42- match = re .search (pattern , text , flags )
43- if match :
44- return match .group (0 )
45- else :
46- raise ValueError ("Erroneous case occurs." )
44+ tree = ast .parse (source )
45+ forward_code = None
46+
47+ for node in tree .body :
48+ if isinstance (node , ast .ClassDef ) and node .name == "GraphModule" :
49+ for fn in node .body :
50+ if isinstance (fn , ast .FunctionDef ) and fn .name == "forward" :
51+ return ast .unparse (fn )
52+ return None
4753
4854
4955def main (args ):
5056 model_path = args .model_path
51- with open ( f" { model_path } /model.py" , "r" ) as fp :
52- model_str = fp . read ( )
53- model_str = extract_from_forward_regex ( model_str )
57+ if args . dump_graph_hash_key :
58+ model_str = extract_forward_source ( model_path )
59+ assert model_str is not None , f" model_str of { args . model_path } is None."
5460 _save_to_model_path (model_path , _get_sha_hash (model_str ))
5561
5662 model_path = args .model_path
@@ -100,17 +106,16 @@ def main(args):
100106 required = True ,
101107 help = "Path to folder e.g '../test_dataset'" ,
102108 )
103-
104109 parser .add_argument (
105110 "--no-check-redundancy" ,
106111 action = "store_true" ,
112+ default = False ,
107113 help = "whether check model graph redundancy" ,
108114 )
109-
110115 parser .add_argument (
111116 "--dump-graph-hash-key" ,
112117 action = "store_true" ,
113- default = False ,
118+ default = True ,
114119 help = "Dump graph hash key" ,
115120 )
116121 parser .add_argument (
0 commit comments