22import os
33import re
44import sys
5- import ast
65import math
7- import importlib
8- import inspect
96import subprocess
107from datetime import datetime
11- from typing import Type
12- from dataclasses import dataclass , field
13- from collections import defaultdict
148
159import paddle
10+ from graph_net import collect_stats_util
1611from graph_net .paddle import utils
1712
1813
def is_single_model_dir(model_dir):
    """Return whether ``model_dir`` directly contains a ``graph_net.json`` file.

    Args:
        model_dir: Path of the directory to inspect.

    Returns:
        bool: ``True`` when ``graph_net.json`` exists as a regular file
        directly inside ``model_dir``, ``False`` otherwise.
    """
    # Build the path with os.path.join (the same way read_graph_source_and_tag
    # locates this file) so that no stray characters can end up between the
    # directory and the marker file name.
    return os.path.isfile(os.path.join(model_dir, "graph_net.json"))
2116
2217
def load_class_from_file(
    file_path: str, class_name: str
) -> "Optional[Type[paddle.nn.Layer]]":
    """Execute a Python source file and return one attribute of the result.

    The file is executed as a module named ``"unnamed"``; this function does
    not register that module in ``sys.modules``.

    Args:
        file_path: Path of the Python source file to execute.
        class_name: Name of the attribute (expected to be a class) to fetch
            from the executed module.

    Returns:
        The attribute called ``class_name``, or ``None`` when the module does
        not define it.

    Raises:
        ImportError: If no import spec/loader can be created for
            ``file_path`` (for example, an unrecognised file suffix).
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` is reachable as an attribute.
    import importlib.util

    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    if spec is None or spec.loader is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ImportError(f"Cannot create an import spec for {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name, None)
29-
30-
def get_argument_name_and_types(model_class, func_name):
    """Map the parameter names of ``model_class.func_name`` to their annotations.

    Args:
        model_class: Class (or other object) on which to look up the function.
        func_name: Name of the function whose signature is inspected.

    Returns:
        dict: Parameter name -> annotation, in signature order. A parameter
        named ``"self"`` is skipped, and parameters without an annotation map
        to ``None``. The dict is empty when ``func_name`` is missing or does
        not resolve to a plain function.
    """
    # Look the attribute up directly instead of scanning every class member.
    func = getattr(model_class, func_name, None)
    if not inspect.isfunction(func):
        return {}

    # inspect.Parameter.empty is the public spelling of the "no annotation"
    # sentinel (the private alias is inspect._empty).
    no_annotation = inspect.Parameter.empty
    return {
        arg_name: (None if param.annotation is no_annotation else param.annotation)
        for arg_name, param in inspect.signature(func).parameters.items()
        if arg_name != "self"
    }
41-
42-
def get_number_of_returns(file_path, class_name, func_name):
    """Count the values returned by a method, by static inspection of its source.

    The file is parsed with ``ast``; nothing in it is executed. Only top-level
    classes of the file and functions defined directly in the class body are
    considered.

    Args:
        file_path: Path of the Python source file to parse.
        class_name: Name of the top-level class that owns the function.
        func_name: Name of the function (``def`` or ``async def``) to inspect.

    Returns:
        int: For the first ``return`` statement found in a breadth-first
        traversal of the function body: ``0`` for a bare ``return``, the
        number of elements for a tuple expression, ``1`` for any other
        expression. ``0`` when the class, the function, or a ``return``
        statement is not found.
    """
    from collections import deque

    # Read bytes so the parser applies the file's own encoding declaration
    # rather than the locale's default text encoding.
    with open(file_path, "rb") as src_file:
        tree = ast.parse(src_file.read())

    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    nested_scopes = func_types + (ast.ClassDef, ast.Lambda)

    for node in tree.body:
        if not (isinstance(node, ast.ClassDef) and node.name == class_name):
            continue
        for member in node.body:
            if not (isinstance(member, func_types) and member.name == func_name):
                continue
            # Breadth-first traversal (same visiting order as ast.walk), but
            # without descending into nested functions/classes: a ``return``
            # inside a helper belongs to the helper, not to ``func_name``.
            pending = deque(ast.iter_child_nodes(member))
            while pending:
                stmt = pending.popleft()
                if isinstance(stmt, ast.Return):
                    if stmt.value is None:
                        return 0
                    if isinstance(stmt.value, ast.Tuple):
                        return len(stmt.value.elts)
                    return 1
                if not isinstance(stmt, nested_scopes):
                    pending.extend(ast.iter_child_nodes(stmt))
    return 0
62-
63-
6418def read_graph_source_and_tag (model_path ):
6519 try :
6620 with open (os .path .join (model_path , "graph_net.json" ), "r" ) as f :
@@ -87,19 +41,6 @@ def get_input_spec(model_path):
8741 return input_spec
8842
8943
@dataclass
class OpStat:
    """Counters kept for a single operator name.

    Attributes:
        op_name: Name of the operator these counters belong to.
        op_dtypes: Mapping from a dtype name to an integer tally; a fresh
            empty dict per instance by default.
        count: Integer tally for the operator; defaults to 0.
    """

    op_name: str
    op_dtypes: dict[str, int] = field(default_factory=dict)
    count: int = 0

    def update(self, other):
        """Add the tallies of ``other`` into this instance, in place.

        Nothing happens unless ``other`` is an ``OpStat`` whose ``op_name``
        equals this instance's ``op_name``.
        """
        if not isinstance(other, OpStat):
            return
        if self.op_name != other.op_name:
            return

        self.count = self.count + other.count
        for dtype_name, occurrences in other.op_dtypes.items():
            previous = self.op_dtypes.get(dtype_name, 0)
            self.op_dtypes[dtype_name] = previous + occurrences
101-
102-
10344class ProgramAnalyzer :
10445 def __init__ (self ):
10546 self .op_stats = {}
@@ -112,7 +53,9 @@ def update_op_stats(self, op_name, op_dtype):
11253 if op_name is not None :
11354 dtype_str = str (op_dtype ).replace ("paddle." , "" )
11455 if self .op_stats .get (op_name , None ) is None :
115- self .op_stats [op_name ] = OpStat (op_name , {dtype_str : 1 }, 1 )
56+ self .op_stats [op_name ] = collect_stats_util .OpStat (
57+ op_name , {dtype_str : 1 }, 1
58+ )
11659 else :
11760 self .op_stats [op_name ].op_dtypes [dtype_str ] = (
11861 self .op_stats [op_name ].op_dtypes .get (dtype_str , 0 ) + 1
@@ -213,9 +156,8 @@ def collect_op_stats(model, model_path):
213156
214157def collect_model_stats (model_path , log_prompt ):
215158 file_path = os .path .join (model_path , "model.py" )
216- model_class = load_class_from_file (file_path , "GraphModule" )
159+ model_class = collect_stats_util . load_class_from_file (file_path , "GraphModule" )
217160 model = model_class ()
218- num_outputs = get_number_of_returns (file_path , "GraphModule" , "forward" )
219161
220162 model_size = 0
221163 input_dtypes = {}
@@ -244,39 +186,33 @@ def collect_model_stats(model_path, log_prompt):
244186 elif name in inputs .keys ():
245187 input_dtypes [dtype_str ] = input_dtypes .get (dtype_str , 0 ) + 1
246188
247- model_size_in_billion = model_size / 1e9
248- num_params = sum ( param_dtypes . values ())
249- num_inputs = sum ( input_dtypes . values () )
250- num_ops = sum ( ops_count_dict . values ())
189+ num_outputs = collect_stats_util . get_number_of_returns (
190+ file_path , "GraphModule" , "forward"
191+ )
192+ num_ops = program_analyzer . num_ops if program_analyzer is not None else 0
251193 source , heuristic_tag = read_graph_source_and_tag (model_path )
252- method = "to_static"
253194 is_complete = (
254195 program_analyzer .is_complete if program_analyzer is not None else False
255196 )
197+ print (
198+ f"model_stats collection information: model_path={ model_path } , method=to_static, is_ops_complete={ is_complete } "
199+ )
256200
257- def dict_to_string (d ):
258- kv_list = [f"{ k } :{ v } " for k , v in d .items ()]
259- return " " .join (kv_list )
260-
261- def print_with_log_prompt (key , value ):
262- print (
263- f"{ log_prompt } [ModelStats.{ key } ] model_path:{ model_path } { value } " ,
264- flush = True ,
265- )
266-
267- print_with_log_prompt ("num_inputs" , num_inputs )
268- print_with_log_prompt ("num_params" , num_params )
269- print_with_log_prompt ("num_outputs" , num_outputs )
270- print_with_log_prompt ("num_ops" , num_ops )
271- print_with_log_prompt ("model_size" , f"{ model_size_in_billion } B" )
272- print_with_log_prompt ("input_dtypes" , dict_to_string (input_dtypes ))
273- print_with_log_prompt ("param_dtypes" , dict_to_string (param_dtypes ))
274- print_with_log_prompt ("op_dtypes" , dict_to_string (op_dtypes ))
275- print_with_log_prompt ("ops" , dict_to_string (ops_count_dict ))
276- print_with_log_prompt ("source" , source )
277- print_with_log_prompt ("heuristic_tag" , heuristic_tag )
278- print_with_log_prompt ("method" , method )
279- print_with_log_prompt ("is_complete" , is_complete )
201+ stats = collect_stats_util .ModelStats (
202+ model_path = model_path ,
203+ num_inputs = sum (input_dtypes .values ()),
204+ num_params = sum (param_dtypes .values ()),
205+ num_outputs = num_outputs ,
206+ num_ops = num_ops ,
207+ model_size_in_billion = model_size / 1e9 ,
208+ input_dtypes = input_dtypes ,
209+ param_dtypes = param_dtypes ,
210+ op_dtypes = op_dtypes ,
211+ ops = ops_count_dict ,
212+ source = source ,
213+ heuristic_tag = heuristic_tag ,
214+ )
215+ collect_stats_util .print_model_stats (stats , log_prompt )
280216
281217
282218def main (args ):
@@ -295,23 +231,9 @@ def main(args):
295231 else args .graph_net_samples_path
296232 )
297233
298- previous_failed_model_pathes = []
299- if args .previous_collect_result_path is not None :
300- with open (args .previous_collect_result_path , "r" ) as f :
301- for line in f .readlines ():
302- if "[ModelStats]" in line :
303- fields = line .strip ().split ()
304- model_path = fields [2 ].split (":" )[- 1 ]
305- is_complete = fields [- 1 ].split (":" )[- 1 ]
306- if is_complete == "False" :
307- previous_failed_model_pathes .append (model_path )
308-
309234 i = 0
310235 for root , dirs , files in os .walk (graph_net_samples_path ):
311- if is_single_model_dir (root ) and (
312- args .previous_collect_result_path is None
313- or root in previous_failed_model_pathes
314- ):
236+ if is_single_model_dir (root ):
315237 print (f"[{ i } ] Collect information for { root } " )
316238 cmd = [
317239 "python" ,
@@ -359,13 +281,6 @@ def main(args):
359281 default = None ,
360282 help = "GraphNet samples directory. e.g '../../paddle_samples'" ,
361283 )
362- parser .add_argument (
363- "--previous-collect-result-path" ,
364- type = str ,
365- required = False ,
366- default = None ,
367- help = "Previous collect result path, use to recollect the failed cases" ,
368- )
369284 parser .add_argument (
370285 "--log-prompt" ,
371286 type = str ,
0 commit comments