 from collections import defaultdict

 import torch
-from torch.fx.passes.shape_prop import ShapeProp
+from functorch import make_fx
 from graph_net.torch import utils

@@ -99,13 +99,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
     return out


-def collect_op_stats(model, input_dict):
+def collect_op_stats_manual(model, input_dict):
     try:
         # FX symbolic trace
         traced = torch.fx.symbolic_trace(model)
         # print(traced.graph)
     except Exception:
-        print("Failed to FX symbolic trace")
+        print("Failed to FX symbolic_trace")
         return False, None

     # Use meta tensors as input to avoid actually running the model
@@ -136,7 +136,6 @@ def collect_op_stats(model, input_dict):
         )

         try:
-            # if True:
             if node.op == "call_module":
                 # classname of module
                 submod = traced.get_submodule(node.target)
@@ -200,23 +199,94 @@ def collect_op_stats(model, input_dict):
     return is_complete, op_stats


+def collect_op_stats_with_make_fx(model, input_dict, arg_types):
+    # Use meta tensors as input to avoid actually running the model
+    meta_input_list = []
+    for arg_name in arg_types.keys():
+        x = input_dict[arg_name]
+        meta_x = (
+            torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
+        )
+        meta_input_list.append(meta_x)
+
+    try:
+        # Generate the FX graph; make_fx fills in meta information automatically
+        fx_model = make_fx(model)(*meta_input_list)
+    except Exception:
+        print("Failed to execute make_fx")
+        return False, None
+
+    is_complete = True
+    op_stats = {}
+    for node in fx_model.graph.nodes:
+        op_name = None
+        if node.op == "call_module":
+            # classname of module
+            submod = fx_model.get_submodule(node.target)
+            op_name = submod.__class__.__name__
+        elif node.op == "call_function":
+            op_name = node.target.__name__
+        elif node.op == "call_method":
+            op_name = node.target
+        elif node.op in ["placeholder", "output", "get_attr"]:
+            op_name = node.op
+        else:
+            assert False, f"node.op: {node.op}"
+
+        dtype = None
+        if node.op != "output":
+            if "tensor_meta" in node.meta:
+                tensor_meta = node.meta["tensor_meta"]
+                dtype = tensor_meta.dtype
+                # print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}")
+            else:
+                print(
+                    f"node.op={node.op}, node.target={node.target} has no tensor_meta!"
+                )
+                is_complete = False
+
+        op_name = (
+            op_name.replace(".default", "")
+            .replace(".Tensor", "")
+            .replace(".Scalar", "")
+        )
+        dtype_str = str(dtype).replace("torch.", "")
+        if op_stats.get(op_name, None) is None:
+            op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
+        else:
+            op_stats[op_name].dtype.add(dtype_str)
+            op_stats[op_name].count = op_stats[op_name].count + 1
+    return is_complete, op_stats
+
+
+def collect_op_stats(model, input_dict, arg_types):
+    is_complete_manual, op_stats_manual = collect_op_stats_manual(model, input_dict)
+    if not is_complete_manual:
+        is_complete_make_fx, op_stats_make_fx = collect_op_stats_with_make_fx(
+            model, input_dict, arg_types
+        )
+        if is_complete_make_fx or op_stats_manual is None:
+            return "make_fx", is_complete_make_fx, op_stats_make_fx
+    return "manual", is_complete_manual, op_stats_manual
+
+
 def collect_model_stats(model_path, device, log_prompt):
     model_class = load_class_from_file(
         os.path.join(model_path, "model.py"), "GraphModule"
     )
     model = model_class()
+    arg_types = get_argument_types(model_class, "forward")
     input_dict = get_input_dict(model_path, device)

     num_ops = 0
-    num_inputs = 0
     num_outputs = 0
     ops_count_info = []
     dtypes = set()
-    is_complete, op_stats = collect_op_stats(model, input_dict)
+    method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types)
     if op_stats is not None:
         for op_name, stat in sorted(op_stats.items()):
             if op_name == "placeholder":
-                num_inputs += stat.count
+                pass
             elif op_name == "output":
                 num_outputs += stat.count
             else:
@@ -226,8 +296,7 @@ def collect_model_stats(model_path, device, log_prompt):
             if v is not None:
                 dtypes.add(v)

-    arg_types = get_argument_types(model_class, "forward")
-    num_inputs = len(arg_types) if op_stats is None else num_inputs
+    num_inputs = len(arg_types)
     num_params = 0
     param_dtypes = set()
     for name, arg_type in arg_types.items():
@@ -242,7 +311,7 @@ def collect_model_stats(model_path, device, log_prompt):
     dtypes_str = "[" + ",".join(dtypes) + "]"
     param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
     print(
-        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete} ops:{ops_str}",
+        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}",
         flush=True,
     )
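
Note on the make_fx path: below is a minimal sketch of the technique the new collect_op_stats_with_make_fx relies on, i.e. tracing through meta tensors so each graph node carries tensor_meta (dtype/shape) without executing any real kernels. The toy function f and its input shape are hypothetical, not part of this repo; in newer PyTorch releases make_fx also lives at torch.fx.experimental.proxy_tensor.make_fx.

import torch
from functorch import make_fx


def f(x):
    # toy model: one unary op and one binary op
    return torch.relu(x) + 1.0


# Trace with a meta tensor: dtypes/shapes propagate, no real compute runs.
meta_x = torch.empty(4, 8, device="meta")
gm = make_fx(f)(meta_x)

for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")
    if tm is not None:
        print(node.op, node.target, tm.dtype, tuple(tm.shape))

This mirrors the fallback in collect_op_stats: symbolic_trace alone leaves nodes without tensor_meta (hence is_complete=False), while make_fx decomposes the model into call_function nodes annotated with meta information, at the cost of requiring concrete (meta) inputs in forward-argument order.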