44import math
55import importlib
66import inspect
7+ import subprocess
78from typing import Type
89from dataclasses import dataclass , field
910from collections import defaultdict
@@ -62,11 +63,31 @@ def resolve_native_multi_head_attention(*args, **kwargs):
6263 (seq_len , batch_size , embed_dim ), dtype = query .dtype , device = "meta"
6364 )
6465
65- # seq_len_k = key.shape[0]
66- # num_heads = args[4]
67- # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
68- # dtype=query.dtype, device='meta')
69- return attn_output # , attn_output_weights
66+ # TODO(Xreki): get value from args
67+ need_weights = False
68+ if need_weights :
69+ seq_len_k = key .shape [0 ]
70+ num_heads = args [4 ]
71+ attn_output_weights = torch .empty (
72+ (batch_size , num_heads , seq_len , seq_len_k ),
73+ dtype = query .dtype ,
74+ device = "meta" ,
75+ )
76+ return attn_output , attn_output_weights
77+ else :
78+ return attn_output
79+
80+
def resolve_tensor_to(tensor, *args, **kwargs):
    """Meta-resolver for ``Tensor.to(...)``.

    Returns an empty meta tensor with ``tensor``'s shape and the target dtype,
    standing in for the real result during meta-device shape propagation.

    The target dtype is resolved in the same order ``Tensor.to`` accepts it:
    the ``dtype=`` keyword, a positional ``torch.dtype``, or the dtype of a
    positional tensor argument (``x.to(other)``). Device-only calls such as
    ``x.to("cuda")`` or ``x.to(device=...)`` keep the input dtype.
    """
    dtype = kwargs.get("dtype")
    if dtype is None:
        for arg in args:
            if isinstance(arg, torch.dtype):
                dtype = arg
                break
            if isinstance(arg, torch.Tensor):
                # x.to(other) converts to other's dtype (and device, which is
                # irrelevant here since the result always lives on "meta").
                dtype = arg.dtype
                break
    if dtype is None:
        # Device-/layout-only conversion: dtype is unchanged.
        dtype = tensor.dtype
    return torch.empty(tensor.shape, dtype=dtype, device="meta")
87+
88+
def resolve_tensor_item(tensor):
    """Meta-resolver for ``Tensor.item()``.

    ``item()`` would yield a Python scalar, which cannot be represented on
    the meta device; return a 0-dim meta tensor of the same dtype instead.
    """
    scalar_placeholder = torch.empty((), dtype=tensor.dtype, device="meta")
    return scalar_placeholder
7091
7192
7293def resolve_get_attr (gm : torch .fx .GraphModule , node : torch .fx .Node ):
@@ -115,6 +136,7 @@ def collect_op_stats(model, input_dict):
115136 )
116137
117138 try :
139+ # if True:
118140 if node .op == "call_module" :
119141 # classname of module
120142 submod = traced .get_submodule (node .target )
@@ -133,8 +155,15 @@ def collect_op_stats(model, input_dict):
133155 op_func = getattr (self_obj , node .target )
134156 node_args = node_args [1 :]
135157
158+ # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
136159 if op_name == "_native_multi_head_attention" :
137160 out = resolve_native_multi_head_attention (* node_args , ** node_kwargs )
161+ elif op_name == "to" :
162+ out = resolve_tensor_to (
163+ node_outputs [node .args [0 ].name ], * node_args , ** node_kwargs
164+ )
165+ elif op_name == "item" :
166+ out = resolve_tensor_item (node_outputs [node .args [0 ].name ])
138167 else :
139168 out = op_func (* node_args , ** node_kwargs )
140169 node_outputs [node .name ] = out
@@ -172,12 +201,6 @@ def collect_op_stats(model, input_dict):
172201
173202
174203def collect_model_stats (model_path , device , log_prompt ):
175- if not hasattr (collect_model_stats , "_counter" ):
176- collect_model_stats ._counter = 0
177- else :
178- collect_model_stats ._counter += 1
179- print (f"[{ collect_model_stats ._counter } ] Collect information for { model_path } " )
180-
181204 model_class = load_class_from_file (
182205 os .path .join (model_path , "model.py" ), "GraphModule"
183206 )
@@ -187,16 +210,18 @@ def collect_model_stats(model_path, device, log_prompt):
187210 num_ops = 0
188211 num_inputs = 0
189212 num_outputs = 0
213+ ops_count_info = []
190214 dtypes = set ()
191215 is_complete , op_stats = collect_op_stats (model , input_dict )
192216 if op_stats is not None :
193- for op_name , stat in op_stats .items ():
217+ for op_name , stat in sorted ( op_stats .items () ):
194218 if op_name == "placeholder" :
195219 num_inputs += stat .count
196220 elif op_name == "output" :
197221 num_outputs += stat .count
198222 else :
199223 num_ops += stat .count
224+ ops_count_info .append (f"{ op_name } ={ stat .count } " )
200225 for v in stat .dtype :
201226 if v is not None :
202227 dtypes .add (v )
@@ -213,11 +238,11 @@ def collect_model_stats(model_path, device, log_prompt):
213238 param_dtypes .add (str (input_dict [name ].dtype ).replace ("torch." , "" ))
214239 num_params_in_billion = num_params / 1e9
215240
241+ ops_str = "[" + "," .join (ops_count_info ) + "]"
216242 dtypes_str = "[" + "," .join (dtypes ) + "]"
217243 param_dtypes_str = "[" + "," .join (param_dtypes ) + "]"
218244 print (
219- f"{ log_prompt } [ModelStats] model_path:{ model_path } num_inputs:{ num_inputs } num_outputs:{ num_outputs } num_ops:{ num_ops } num_params:{ num_params_in_billion } B param_dtypes:{ param_dtypes_str } op_dtypes:{ dtypes_str } is_complete:{ is_complete } " ,
220- file = sys .stderr ,
245+ f"{ log_prompt } [ModelStats] model_path:{ model_path } num_inputs:{ num_inputs } num_outputs:{ num_outputs } num_ops:{ num_ops } num_params:{ num_params_in_billion } B param_dtypes:{ param_dtypes_str } op_dtypes:{ dtypes_str } is_complete:{ is_complete } ops:{ ops_str } " ,
221246 flush = True ,
222247 )
223248
@@ -226,16 +251,36 @@ def main(args):
226251 if args .model_path is not None :
227252 assert os .path .isdir (args .model_path )
228253 assert is_single_model_dir (args .model_path )
254+ print (f"Collect information for { args .model_path } " )
229255 collect_model_stats (args .model_path , args .device , args .log_prompt )
230256 else :
231257 graph_net_samples_path = (
232258 (graph_net .torch .samples_util .get_default_samples_directory ())
233259 if args .graph_net_samples_path is None
234260 else args .graph_net_samples_path
235261 )
262+ i = 0
236263 for root , dirs , files in os .walk (graph_net_samples_path ):
237264 if is_single_model_dir (root ):
238- collect_model_stats (root , args .device , args .log_prompt )
265+ print (f"[{ i } ] Collect information for { root } " )
266+ cmd = [
267+ "python" ,
268+ "-m" ,
269+ "graph_net.torch.collect_stats" ,
270+ f"--device={ args .device } " ,
271+ f"--model-path={ root } " ,
272+ f"--log-prompt={ args .log_prompt } " ,
273+ ]
274+ result = subprocess .run (
275+ cmd ,
276+ stdout = subprocess .PIPE ,
277+ stderr = subprocess .PIPE ,
278+ text = True ,
279+ timeout = 600 ,
280+ )
281+ if result .returncode == 0 :
282+ print (result .stdout )
283+ i += 1
239284
240285
241286if __name__ == "__main__" :
0 commit comments