@@ -52,7 +52,7 @@ def get_input_dict(model_path, device):
 @dataclass
 class OpStat:
     op_name: str
-    dtype: set[str] = field(default_factory=set)
+    op_dtypes: dict[str, int] = field(default_factory=dict)
     count: int = 0
 
 
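Note (not part of the patch): a minimal sketch of how the reworked OpStat is meant to be used — op_dtypes now records how many times each dtype occurs for an op, rather than just the set of dtype strings. The op name and dtype values below are made up for illustration.

    from dataclasses import dataclass, field

    @dataclass
    class OpStat:
        op_name: str
        op_dtypes: dict[str, int] = field(default_factory=dict)
        count: int = 0

    stat = OpStat("aten.add.Tensor")  # hypothetical op name
    for dtype_str in ["float16", "float16", "float32"]:  # illustrative dtypes
        stat.op_dtypes[dtype_str] = stat.op_dtypes.get(dtype_str, 0) + 1
        stat.count += 1
    # stat.op_dtypes == {"float16": 2, "float32": 1}; stat.count == 3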
@@ -124,7 +124,6 @@ def collect_op_stats_manual(model, input_dict):
         if node.op == "placeholder":
             node_outputs[node.name] = meta_input_dict[node.target]
             op_name = node.op
-            dtype = node_outputs[node.name].dtype
         elif node.op in ["call_function", "call_module", "call_method"]:
             node_args = torch.fx.map_arg(
                 node.args,
@@ -190,11 +189,13 @@ def collect_op_stats_manual(model, input_dict):
             assert False, f"node.op: {node.op}"
 
         if op_name is not None:
-            dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None
+            dtype_str = str(dtype).replace("torch.", "")
             if op_stats.get(op_name, None) is None:
-                op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
+                op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
             else:
-                op_stats[op_name].dtype.add(dtype_str)
+                op_stats[op_name].op_dtypes[dtype_str] = (
+                    op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+                )
             op_stats[op_name].count = op_stats[op_name].count + 1
     return is_complete, op_stats
 
@@ -234,7 +235,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
             assert False, f"node.op: {node.op}"
 
         dtype = None
-        if node.op != "output":
+        if node.op not in ["placeholder", "output"]:
            if "tensor_meta" in node.meta:
                tensor_meta = node.meta["tensor_meta"]
                dtype = tensor_meta.dtype
@@ -252,9 +253,11 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
             )
         dtype_str = str(dtype).replace("torch.", "")
         if op_stats.get(op_name, None) is None:
-            op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
+            op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
         else:
-            op_stats[op_name].dtype.add(dtype_str)
+            op_stats[op_name].op_dtypes[dtype_str] = (
+                op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+            )
         op_stats[op_name].count = op_stats[op_name].count + 1
     return is_complete, op_stats
 
@@ -280,8 +283,8 @@ def collect_model_stats(model_path, device, log_prompt):
 
     num_ops = 0
     num_outputs = 0
-    ops_count_info = []
-    dtypes = set()
+    ops_count_dict = {}
+    op_dtypes = {}
     method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types)
     if op_stats is not None:
         for op_name, stat in sorted(op_stats.items()):
@@ -291,29 +294,48 @@ def collect_model_stats(model_path, device, log_prompt):
                 num_outputs += stat.count
             else:
                 num_ops += stat.count
-                ops_count_info.append(f"{op_name}={stat.count}")
-            for v in stat.dtype:
-                if v is not None:
-                    dtypes.add(v)
+                ops_count_dict[op_name] = stat.count
+            for dtype_str, num in stat.op_dtypes.items():
+                if dtype_str is not None and dtype_str != "None":
+                    op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
 
-    num_inputs = len(arg_types)
     num_params = 0
-    param_dtypes = set()
+    model_size = 0
+    input_dtypes = {}
+    param_dtypes = {}
     for name, arg_type in arg_types.items():
         if arg_type == torch.nn.parameter.Parameter:
-            count = math.prod(input_dict[name].shape)
+            param_numel = math.prod(input_dict[name].shape)
             # print(f"Parameter {name}: {count}")
-            num_params += count
-            param_dtypes.add(str(input_dict[name].dtype).replace("torch.", ""))
-    num_params_in_billion = num_params / 1e9
-
-    ops_str = "[" + ",".join(ops_count_info) + "]"
-    dtypes_str = "[" + ",".join(dtypes) + "]"
-    param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
-    print(
-        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}",
-        flush=True,
-    )
+            num_params += 1
+            model_size += param_numel
+            dtype_str = str(input_dict[name].dtype).replace("torch.", "")
+            param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
+        else:
+            dtype_str = str(input_dict[name].dtype).replace("torch.", "")
+            input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
+    model_size_in_billion = model_size / 1e9
+    num_inputs = len(arg_types) - num_params
+
+    def dict_to_string(d):
+        kv_list = [f"{k}={v}" for k, v in d.items()]
+        return "{" + ",".join(kv_list) + "}"
+
+    log_fields = [log_prompt, "[ModelStats]"]
+    log_fields.append(f"model_path:{model_path}")
+    log_fields.append(f"num_inputs:{num_inputs}")
+    log_fields.append(f"num_params:{num_params}")
+    log_fields.append(f"num_outputs:{num_outputs}")
+    log_fields.append(f"num_ops:{num_ops}")
+    log_fields.append(f"model_size:{model_size_in_billion}B")
+    log_fields.append(f"input_dtypes:{dict_to_string(input_dtypes)}")
+    log_fields.append(f"param_dtypes:{dict_to_string(param_dtypes)}")
+    log_fields.append(f"op_dtypes:{dict_to_string(op_dtypes)}")
+    log_fields.append(f"ops:{dict_to_string(ops_count_dict)}")
+    log_fields.append(f"method:{method}")
+    log_fields.append(f"is_complete:{is_complete}")
+
+    print(" ".join(log_fields), flush=True)
 
 
 def main(args):
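Note (not part of the patch): the stats are now emitted as one line of space-separated key:value fields, with dict-valued fields rendered by dict_to_string. A small sketch of that formatting, using made-up counts:

    def dict_to_string(d):
        kv_list = [f"{k}={v}" for k, v in d.items()]
        return "{" + ",".join(kv_list) + "}"

    # Illustrative values, not taken from a real run:
    print(dict_to_string({"float16": 2, "float32": 1}))  # -> {float16=2,float32=1}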