@@ -52,7 +52,7 @@ def get_input_dict(model_path, device):
 @dataclass
 class OpStat:
     op_name: str
-    dtype: set[str] = field(default_factory=set)
+    op_dtypes: dict[str, int] = field(default_factory=dict)
     count: int = 0
 
 
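For context, a minimal standalone sketch (not part of the patch) of how the reworked OpStat is meant to accumulate per-dtype counts; the op name and dtype strings below are hypothetical.

from dataclasses import dataclass, field

@dataclass
class OpStat:
    op_name: str
    op_dtypes: dict[str, int] = field(default_factory=dict)
    count: int = 0

# Hypothetical op name and dtypes, only to show how the dict accumulates.
stat = OpStat("aten.mm")
for dtype_str in ["float16", "float16", "float32"]:
    stat.op_dtypes[dtype_str] = stat.op_dtypes.get(dtype_str, 0) + 1
    stat.count += 1
print(stat)
# OpStat(op_name='aten.mm', op_dtypes={'float16': 2, 'float32': 1}, count=3)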
@@ -124,7 +124,6 @@ def collect_op_stats_manual(model, input_dict):
         if node.op == "placeholder":
             node_outputs[node.name] = meta_input_dict[node.target]
             op_name = node.op
-            dtype = node_outputs[node.name].dtype
         elif node.op in ["call_function", "call_module", "call_method"]:
             node_args = torch.fx.map_arg(
                 node.args,
@@ -190,11 +189,13 @@ def collect_op_stats_manual(model, input_dict):
             assert False, f"node.op: {node.op}"
 
         if op_name is not None:
-            dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None
+            dtype_str = str(dtype).replace("torch.", "")
             if op_stats.get(op_name, None) is None:
-                op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
+                op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
             else:
-                op_stats[op_name].dtype.add(dtype_str)
+                op_stats[op_name].op_dtypes[dtype_str] = (
+                    op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+                )
             op_stats[op_name].count = op_stats[op_name].count + 1
     return is_complete, op_stats
 
@@ -234,7 +235,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
             assert False, f"node.op: {node.op}"
 
         dtype = None
-        if node.op != "output":
+        if node.op not in ["placeholder", "output"]:
             if "tensor_meta" in node.meta:
                 tensor_meta = node.meta["tensor_meta"]
                 dtype = tensor_meta.dtype
@@ -252,9 +253,11 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
             )
         dtype_str = str(dtype).replace("torch.", "")
         if op_stats.get(op_name, None) is None:
-            op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
+            op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
         else:
-            op_stats[op_name].dtype.add(dtype_str)
+            op_stats[op_name].op_dtypes[dtype_str] = (
+                op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
+            )
         op_stats[op_name].count = op_stats[op_name].count + 1
     return is_complete, op_stats
 
@@ -280,8 +283,8 @@ def collect_model_stats(model_path, device, log_prompt):
 
     num_ops = 0
     num_outputs = 0
-    ops_count_info = []
-    dtypes = set()
+    ops_count_dict = {}
+    op_dtypes = {}
     method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types)
     if op_stats is not None:
         for op_name, stat in sorted(op_stats.items()):
@@ -291,29 +294,48 @@ def collect_model_stats(model_path, device, log_prompt):
                 num_outputs += stat.count
             else:
                 num_ops += stat.count
-                ops_count_info.append(f"{op_name}={stat.count}")
-                for v in stat.dtype:
-                    if v is not None:
-                        dtypes.add(v)
+                ops_count_dict[op_name] = stat.count
+                for dtype_str, num in stat.op_dtypes.items():
+                    if dtype_str is not None and dtype_str != "None":
+                        op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
 
-    num_inputs = len(arg_types)
     num_params = 0
-    param_dtypes = set()
+    model_size = 0
+    input_dtypes = {}
+    param_dtypes = {}
     for name, arg_type in arg_types.items():
         if arg_type == torch.nn.parameter.Parameter:
-            count = math.prod(input_dict[name].shape)
+            param_numel = math.prod(input_dict[name].shape)
             # print(f"Parameter {name}: {count}")
-            num_params += count
-            param_dtypes.add(str(input_dict[name].dtype).replace("torch.", ""))
-    num_params_in_billion = num_params / 1e9
-
-    ops_str = "[" + ",".join(ops_count_info) + "]"
-    dtypes_str = "[" + ",".join(dtypes) + "]"
-    param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
-    print(
-        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}",
-        flush=True,
-    )
+            num_params += 1
+            model_size += param_numel
+            dtype_str = str(input_dict[name].dtype).replace("torch.", "")
+            param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
+        else:
+            dtype_str = str(input_dict[name].dtype).replace("torch.", "")
+            input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
+    model_size_in_billion = model_size / 1e9
+    num_inputs = len(arg_types) - num_params
+
+    def dict_to_string(d):
+        kv_list = [f"{k}={v}" for k, v in d.items()]
+        return "{" + ",".join(kv_list) + "}"
+
+    log_fields = [log_prompt, "[ModelStats]"]
+    log_fields.append(f"model_path:{model_path}")
+    log_fields.append(f"num_inputs:{num_inputs}")
+    log_fields.append(f"num_params:{num_params}")
+    log_fields.append(f"num_outputs:{num_outputs}")
+    log_fields.append(f"num_ops:{num_ops}")
+    log_fields.append(f"model_size:{model_size_in_billion}B")
+    log_fields.append(f"input_dtypes:{dict_to_string(input_dtypes)}")
+    log_fields.append(f"param_dtypes:{dict_to_string(param_dtypes)}")
+    log_fields.append(f"op_dtypes:{dict_to_string(op_dtypes)}")
+    log_fields.append(f"ops:{dict_to_string(ops_count_dict)}")
+    log_fields.append(f"method:{method}")
+    log_fields.append(f"is_complete:{is_complete}")
+
+    print(" ".join(log_fields), flush=True)
 
 
 def main(args):
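For context, a minimal standalone sketch (not part of the patch) of the log line the new field-based formatting produces; dict_to_string is copied from the change above, and every field value shown is made up.

def dict_to_string(d):
    kv_list = [f"{k}={v}" for k, v in d.items()]
    return "{" + ",".join(kv_list) + "}"

# Hypothetical stats, only to illustrate the output shape.
op_dtypes = {"float16": 120, "float32": 4}
ops_count_dict = {"aten.mm": 48, "aten.add.Tensor": 76}
log_fields = ["PROMPT", "[ModelStats]"]
log_fields.append("num_inputs:2")
log_fields.append("num_params:245")
log_fields.append(f"op_dtypes:{dict_to_string(op_dtypes)}")
log_fields.append(f"ops:{dict_to_string(ops_count_dict)}")
print(" ".join(log_fields), flush=True)
# PROMPT [ModelStats] num_inputs:2 num_params:245 op_dtypes:{float16=120,float32=4} ops:{aten.mm=48,aten.add.Tensor=76}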