@@ -4,6 +4,7 @@
 import math
 import importlib
 import inspect
+import subprocess
 from typing import Type
 from dataclasses import dataclass, field
 from collections import defaultdict
@@ -62,11 +63,31 @@ def resolve_native_multi_head_attention(*args, **kwargs):
         (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
     )

-    # seq_len_k = key.shape[0]
-    # num_heads = args[4]
-    # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
-    #                                   dtype=query.dtype, device='meta')
-    return attn_output  # , attn_output_weights
+    # TODO(Xreki): get value from args
+    need_weights = False
+    if need_weights:
+        seq_len_k = key.shape[0]
+        num_heads = args[4]
+        attn_output_weights = torch.empty(
+            (batch_size, num_heads, seq_len, seq_len_k),
+            dtype=query.dtype,
+            device="meta",
+        )
+        return attn_output, attn_output_weights
+    else:
+        return attn_output
+
+
+def resolve_tensor_to(tensor, *args, **kwargs):
+    if args and isinstance(args[0], torch.dtype):
+        dtype = args[0]
+    else:
+        dtype = kwargs.get("dtype", tensor.dtype)
+    return torch.empty(tensor.shape, dtype=dtype, device="meta")
+
+
+def resolve_tensor_item(tensor):
+    return torch.empty((), dtype=tensor.dtype, device="meta")


 def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
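Note: these resolvers all follow one pattern. When an op cannot run directly on the `meta` device, they hand-build an output with the shape and dtype the real op would produce. A minimal self-contained sketch of the idea (tensor shapes here are invented for illustration):

```python
import torch

# Meta tensors carry shape/dtype but no storage, so data-dependent calls such
# as Tensor.item() raise; the resolvers above substitute torch.empty(...,
# device="meta") outputs of the correct shape and dtype instead.
t = torch.empty((4, 8), dtype=torch.float32, device="meta")

# Stand-in for t.to(torch.float16): same shape, new dtype, still on meta.
half = torch.empty(t.shape, dtype=torch.float16, device="meta")

# Stand-in for t.item(): a 0-dim meta tensor of the same dtype.
scalar = torch.empty((), dtype=t.dtype, device="meta")

print(half.shape, half.dtype)  # torch.Size([4, 8]) torch.float16
print(scalar.dim())            # 0
```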
@@ -115,6 +136,7 @@ def collect_op_stats(model, input_dict):
         )

         try:
+        # if True:
             if node.op == "call_module":
                 # classname of module
                 submod = traced.get_submodule(node.target)
@@ -133,8 +155,15 @@ def collect_op_stats(model, input_dict):
                 op_func = getattr(self_obj, node.target)
                 node_args = node_args[1:]

+            # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
             if op_name == "_native_multi_head_attention":
                 out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
+            elif op_name == "to":
+                out = resolve_tensor_to(
+                    node_outputs[node.args[0].name], *node_args, **node_kwargs
+                )
+            elif op_name == "item":
+                out = resolve_tensor_item(node_outputs[node.args[0].name])
             else:
                 out = op_func(*node_args, **node_kwargs)
             node_outputs[node.name] = out
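The `node_outputs[node.args[0].name]` lookup works because the loop visits the FX graph in topological order and caches every node's value under `node.name`, so a producer's output is always available when a consumer needs it. A standalone sketch of that pattern (the module and shapes are made up for illustration):

```python
import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.float16).sum()

traced = torch.fx.symbolic_trace(Toy())
node_outputs = {}
for node in traced.graph.nodes:
    if node.op == "placeholder":
        node_outputs[node.name] = torch.empty((2, 3), device="meta")
    elif node.op == "call_method" and node.target == "to":
        # Producer value is recovered by name, as in the dispatch above.
        src = node_outputs[node.args[0].name]
        node_outputs[node.name] = torch.empty(src.shape, dtype=node.args[1], device="meta")
    elif node.op == "call_method":  # here only sum(), which yields a 0-dim result
        src = node_outputs[node.args[0].name]
        node_outputs[node.name] = torch.empty((), dtype=src.dtype, device="meta")

print({name: tuple(v.shape) for name, v in node_outputs.items()})
```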
@@ -172,12 +201,6 @@ def collect_op_stats(model, input_dict):


 def collect_model_stats(model_path, device, log_prompt):
-    if not hasattr(collect_model_stats, "_counter"):
-        collect_model_stats._counter = 0
-    else:
-        collect_model_stats._counter += 1
-    print(f"[{collect_model_stats._counter}] Collect information for {model_path}")
-
     model_class = load_class_from_file(
         os.path.join(model_path, "model.py"), "GraphModule"
     )
@@ -187,16 +210,18 @@ def collect_model_stats(model_path, device, log_prompt):
     num_ops = 0
     num_inputs = 0
     num_outputs = 0
+    ops_count_info = []
     dtypes = set()
     is_complete, op_stats = collect_op_stats(model, input_dict)
     if op_stats is not None:
-        for op_name, stat in op_stats.items():
+        for op_name, stat in sorted(op_stats.items()):
             if op_name == "placeholder":
                 num_inputs += stat.count
             elif op_name == "output":
                 num_outputs += stat.count
             else:
                 num_ops += stat.count
+                ops_count_info.append(f"{op_name}={stat.count}")
             for v in stat.dtype:
                 if v is not None:
                     dtypes.add(v)
@@ -213,11 +238,11 @@ def collect_model_stats(model_path, device, log_prompt):
         param_dtypes.add(str(input_dict[name].dtype).replace("torch.", ""))
     num_params_in_billion = num_params / 1e9

+    ops_str = "[" + ",".join(ops_count_info) + "]"
     dtypes_str = "[" + ",".join(dtypes) + "]"
     param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
     print(
-        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}",
-        file=sys.stderr,
+        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete} ops:{ops_str}",
         flush=True,
     )

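For reference, the new `ops:` field renders the sorted per-op tallies into one bracketed list; with made-up counts it looks like this:

```python
ops_count_info = ["addmm=4", "gelu=2", "layer_norm=5"]  # hypothetical tallies
ops_str = "[" + ",".join(ops_count_info) + "]"
print(ops_str)  # [addmm=4,gelu=2,layer_norm=5]
```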
@@ -226,16 +251,36 @@ def main(args):
     if args.model_path is not None:
         assert os.path.isdir(args.model_path)
         assert is_single_model_dir(args.model_path)
+        print(f"Collect information for {args.model_path}")
         collect_model_stats(args.model_path, args.device, args.log_prompt)
     else:
         graph_net_samples_path = (
             (graph_net.torch.samples_util.get_default_samples_directory())
             if args.graph_net_samples_path is None
             else args.graph_net_samples_path
         )
+        i = 0
         for root, dirs, files in os.walk(graph_net_samples_path):
             if is_single_model_dir(root):
-                collect_model_stats(root, args.device, args.log_prompt)
+                print(f"[{i}] Collect information for {root}")
+                cmd = [
+                    "python",
+                    "-m",
+                    "graph_net.torch.collect_stats",
+                    f"--device={args.device}",
+                    f"--model-path={root}",
+                    f"--log-prompt={args.log_prompt}",
+                ]
+                result = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    timeout=600,
+                )
+                if result.returncode == 0:
+                    print(result.stdout)
+                i += 1


 if __name__ == "__main__":
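Running each sample through `python -m graph_net.torch.collect_stats` in a child process isolates per-model crashes from the sweep, but the loop above has two sharp edges: `subprocess.run(..., timeout=600)` raises `subprocess.TimeoutExpired` rather than returning, and non-zero exits are dropped silently. A hedged wrapper sketch (the helper name is invented) that a caller could substitute:

```python
import subprocess

def run_isolated(cmd, timeout=600):
    # Returns captured stdout on success, None on timeout or failure.
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f"timeout after {timeout}s: {' '.join(cmd)}")
        return None
    if result.returncode != 0:
        print(f"exit code {result.returncode}: {result.stderr.strip()}")
        return None
    return result.stdout
```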