1010from collections import defaultdict
1111
1212import torch
13- from torch . fx . passes . shape_prop import ShapeProp
13+ from functorch import make_fx
1414from graph_net .torch import utils
1515
1616
@@ -99,13 +99,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
9999 return out
100100
101101
102- def collect_op_stats (model , input_dict ):
102+ def collect_op_stats_manual (model , input_dict ):
103103 try :
104104 # FX symbolic trace
105105 traced = torch .fx .symbolic_trace (model )
106106 # print(traced.graph)
107107 except Exception :
108- print ("Failed to FX symbolic trace " )
108+ print ("Failed to FX symbolic_trace " )
109109 return False , None
110110
111111 # Use meta tensors as input to avoid actually running the model
@@ -136,7 +136,6 @@ def collect_op_stats(model, input_dict):
136136 )
137137
138138 try :
139- # if True:
140139 if node .op == "call_module" :
141140 # classname of module
142141 submod = traced .get_submodule (node .target )
@@ -200,23 +199,94 @@ def collect_op_stats(model, input_dict):
200199 return is_complete , op_stats
201200
202201
def collect_op_stats_with_make_fx(model, input_dict, arg_types):
    """Collect per-op statistics by tracing ``model`` with functorch's make_fx.

    Inputs are replaced with meta tensors so the model is traced without
    actually executing any real computation; make_fx fills in the
    ``tensor_meta`` on each node, which we use to record output dtypes.

    Args:
        model: the callable / nn.Module to trace.
        input_dict: mapping of argument name -> input value (tensors or
            plain Python values).
        arg_types: ordered mapping of forward-argument names; its key order
            determines the positional order in which inputs are passed.

    Returns:
        (is_complete, op_stats) where ``is_complete`` is False when tracing
        failed or some node lacked tensor_meta, and ``op_stats`` maps
        op name -> OpStat (None when tracing failed entirely).
    """
    # Use meta tensors as input to avoid actually running the model.
    meta_input_list = []
    for arg_name in arg_types.keys():
        x = input_dict[arg_name]
        meta_x = (
            torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
        )
        meta_input_list.append(meta_x)

    try:
        # Generate FX Graph, and automatically fill in meta information.
        fx_model = make_fx(model)(*meta_input_list)
    except Exception:
        print("Failed to execute make_fx")
        return False, None

    is_complete = True
    op_stats = {}
    for node in fx_model.graph.nodes:
        op_name = None
        if node.op == "call_module":
            # BUG FIX: the original referenced `traced`, a local variable of
            # collect_op_stats_manual that does not exist in this scope
            # (NameError on any call_module node). The traced module here
            # is `fx_model`.
            submod = fx_model.get_submodule(node.target)
            # classname of module
            op_name = submod.__class__.__name__
        elif node.op == "call_function":
            op_name = node.target.__name__
        elif node.op == "call_method":
            op_name = node.target
        elif node.op in ["placeholder", "output", "get_attr"]:
            op_name = node.op
        else:
            assert False, f"node.op: {node.op}"

        dtype = None
        if node.op != "output":
            if "tensor_meta" in node.meta:
                tensor_meta = node.meta["tensor_meta"]
                dtype = tensor_meta.dtype
            else:
                # Missing meta info means the stats cannot be trusted as
                # complete; report it but keep collecting.
                print(
                    f"node.op={node.op}, node.target={node.target} has no tensor_meta!"
                )
                is_complete = False

        # Normalize overload suffixes (e.g. aten.add.Tensor -> aten.add).
        op_name = (
            op_name.replace(".default", "")
            .replace(".Tensor", "")
            .replace(".Scalar", "")
        )
        dtype_str = str(dtype).replace("torch.", "")
        if op_stats.get(op_name, None) is None:
            op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
        else:
            op_stats[op_name].dtype.add(dtype_str)
            op_stats[op_name].count = op_stats[op_name].count + 1
    return is_complete, op_stats
260+
261+
def collect_op_stats(model, input_dict, arg_types):
    """Collect op statistics, preferring the manual FX trace.

    Runs the symbolic-trace based collector first; when its result is
    incomplete, retries with make_fx and returns that result if it is
    complete (or if the manual pass produced nothing at all).

    Returns:
        (method, is_complete, op_stats) where ``method`` names which
        collector produced the returned stats ("manual" or "make_fx").
    """
    manual_complete, manual_stats = collect_op_stats_manual(model, input_dict)
    # Fast path: the manual trace already covered every node.
    if manual_complete:
        return "manual", manual_complete, manual_stats

    fx_complete, fx_stats = collect_op_stats_with_make_fx(
        model, input_dict, arg_types
    )
    # Prefer the make_fx result when it is complete, or when the manual
    # pass failed outright and left us nothing to report.
    if fx_complete or manual_stats is None:
        return "make_fx", fx_complete, fx_stats
    return "manual", manual_complete, manual_stats
271+
272+
203273def collect_model_stats (model_path , device , log_prompt ):
204274 model_class = load_class_from_file (
205275 os .path .join (model_path , "model.py" ), "GraphModule"
206276 )
207277 model = model_class ()
278+ arg_types = get_argument_types (model_class , "forward" )
208279 input_dict = get_input_dict (model_path , device )
209280
210281 num_ops = 0
211- num_inputs = 0
212282 num_outputs = 0
213283 ops_count_info = []
214284 dtypes = set ()
215- is_complete , op_stats = collect_op_stats (model , input_dict )
285+ method , is_complete , op_stats = collect_op_stats (model , input_dict , arg_types )
216286 if op_stats is not None :
217287 for op_name , stat in sorted (op_stats .items ()):
218288 if op_name == "placeholder" :
219- num_inputs += stat . count
289+ pass
220290 elif op_name == "output" :
221291 num_outputs += stat .count
222292 else :
@@ -226,8 +296,7 @@ def collect_model_stats(model_path, device, log_prompt):
226296 if v is not None :
227297 dtypes .add (v )
228298
229- arg_types = get_argument_types (model_class , "forward" )
230- num_inputs = len (arg_types ) if op_stats is None else num_inputs
299+ num_inputs = len (arg_types )
231300 num_params = 0
232301 param_dtypes = set ()
233302 for name , arg_type in arg_types .items ():
@@ -242,7 +311,7 @@ def collect_model_stats(model_path, device, log_prompt):
242311 dtypes_str = "[" + "," .join (dtypes ) + "]"
243312 param_dtypes_str = "[" + "," .join (param_dtypes ) + "]"
244313 print (
245- f"{ log_prompt } [ModelStats] model_path:{ model_path } num_inputs:{ num_inputs } num_outputs:{ num_outputs } num_ops:{ num_ops } num_params:{ num_params_in_billion } B param_dtypes:{ param_dtypes_str } op_dtypes:{ dtypes_str } is_complete:{ is_complete } ops:{ ops_str } " ,
314+ f"{ log_prompt } [ModelStats] model_path:{ model_path } num_inputs:{ num_inputs } num_outputs:{ num_outputs } num_ops:{ num_ops } num_params:{ num_params_in_billion } B param_dtypes:{ param_dtypes_str } op_dtypes:{ dtypes_str } method: { method } is_complete:{ is_complete } ops:{ ops_str } " ,
246315 flush = True ,
247316 )
248317
0 commit comments