@@ -55,6 +55,12 @@ class OpStat:
55
55
op_dtypes : dict [str , int ] = field (default_factory = dict )
56
56
count : int = 0
57
57
58
+ def update (self , other ):
59
+ if isinstance (other , OpStat ) and self .op_name == other .op_name :
60
+ self .count += other .count
61
+ for name , count in other .op_dtypes .items ():
62
+ self .op_dtypes [name ] = self .op_dtypes .get (name , 0 ) + count
63
+
58
64
59
65
def resolve_native_multi_head_attention (* args , ** kwargs ):
60
66
query , key , value = args [0 ], args [1 ], args [2 ]
@@ -132,19 +138,23 @@ def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs):
132
138
return None
133
139
134
140
135
- def collect_op_stats_manual (model , input_dict , device ):
136
- try :
137
- # FX symbolic trace
138
- traced = torch .fx .symbolic_trace (model )
139
- # print(traced.graph)
140
- except Exception :
141
- print ("Failed to FX symbolic_trace" )
142
- return False , None
141
# Relax torch._dynamo's graph-capture restrictions so models that use scalar
# outputs, dynamic output shapes, sparse compute, context managers, or RNNs
# are still captured into FX subgraphs (instead of raising / graph-breaking),
# which lets the meta-tensor backend below see as many ops as possible.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_sparse_compute = True
torch._dynamo.config.raise_on_ctx_manager_usage = False
torch._dynamo.config.allow_rnn = True
143
146
144
- # Use meta tensors as input to avoid actually running the model
145
- meta_input_dict = convert_real_to_meta (input_dict )
146
147
147
- def get_output_dtype (out ):
148
+ class GraphMetaExecutor :
149
    def __init__(self, device):
        """Create an executor that interprets FX graphs on meta tensors.

        Args:
            device: real device used as a fallback when an op cannot be
                executed on meta tensors (passed to resolve_with_real_tensor).
        """
        self.device = device
        # op name -> OpStat, accumulated across all subgraphs seen so far
        self.op_stats = {}
        # becomes False once any op's output dtype could not be inferred
        self.is_complete = True
        # total operator invocations across all subgraphs
        self.num_ops = 0
        # operators whose output dtype inference failed
        self.num_ops_misses_dtypes = 0
        # number of FX subgraphs this executor has processed
        self.subgraph_counter = 0
156
+
157
+ def get_output_dtype (self , out ):
148
158
if isinstance (out , torch .Tensor ):
149
159
return out .dtype
150
160
if (
@@ -156,102 +166,160 @@ def get_output_dtype(out):
156
166
else :
157
167
return None
158
168
159
- is_complete = True
160
- op_stats = {}
161
- node_outputs = {}
162
- for node in traced .graph .nodes :
169
+ def get_op_name_and_func (self , node , node_outputs ):
163
170
op_name = None
164
- dtype = None
165
- if node .op == "placeholder" :
166
- node_outputs [node .name ] = meta_input_dict [node .target ]
167
- op_name = node .op
168
- elif node .op in ["call_function" , "call_module" , "call_method" ]:
169
- node_args = torch .fx .map_arg (
170
- node .args ,
171
- lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
172
- )
173
- node_kwargs = torch .fx .map_arg (
174
- node .kwargs ,
175
- lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
176
- )
177
-
178
- try :
179
- if node .op == "call_module" :
180
- # classname of module
181
- submod = traced .get_submodule (node .target )
182
- op_name = submod .__class__ .__name__
183
- op_func = submod
184
- elif node .op == "call_function" :
185
- op_name = node .target .__name__
186
- op_func = node .target
187
- elif node .op == "call_method" :
188
- op_name = node .target
189
- self_obj = (
190
- node_outputs [node .args [0 ].name ]
191
- if isinstance (node .args [0 ], torch .fx .Node )
192
- else node .args [0 ]
193
- )
194
- op_func = getattr (self_obj , node .target )
195
- node_args = node_args [1 :]
196
-
197
- # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
198
- if op_name == "_native_multi_head_attention" :
199
- out = resolve_native_multi_head_attention (* node_args , ** node_kwargs )
200
- elif op_name == "to" :
201
- # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
202
- out = resolve_tensor_to (
203
- node_outputs [node .args [0 ].name ], * node_args , ** node_kwargs
204
- )
205
- elif op_name == "item" :
206
- out = resolve_tensor_item (node_outputs [node .args [0 ].name ])
207
- else :
208
- out = op_func (* node_args , ** node_kwargs )
209
- node_outputs [node .name ] = out
210
- dtype = get_output_dtype (out )
211
- except Exception :
212
- out = resolve_with_real_tensor (op_func , device , node_args , node_kwargs )
213
- node_outputs [node .name ] = out
214
- if out is not None :
215
- dtype = get_output_dtype (out )
216
- else :
217
- print (
218
- f"dtype inference failed: node.op={ node .op } , op_name={ op_name } "
219
- )
220
- is_complete = False
221
- elif node .op == "get_attr" :
222
- op_name = node .op
223
- out = resolve_get_attr (traced , node )
224
- node_outputs [node .name ] = out
225
- dtype = get_output_dtype (out )
226
- elif node .op == "output" :
227
- op_name = node .op
228
- node_args = torch .fx .map_arg (
229
- node .args ,
230
- lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
231
- )
232
- node_outputs [node .name ] = node_args [0 ] if len (node_args ) == 1 else node_args
233
- dtype = get_output_dtype (node_args [0 ])
234
- else :
235
- assert False , f"node.op: { node .op } "
236
-
171
+ op_func = None
172
+ try :
173
+ if node .op == "call_module" :
174
+ # classname of module
175
+ submod = traced .get_submodule (node .target )
176
+ op_name = submod .__class__ .__name__
177
+ op_func = submod
178
+ elif node .op == "call_function" :
179
+ op_name = node .target .__name__
180
+ op_func = node .target
181
+ elif node .op == "call_method" :
182
+ op_name = node .target
183
+ self_obj = (
184
+ node_outputs [node .args [0 ].name ]
185
+ if isinstance (node .args [0 ], torch .fx .Node )
186
+ else node .args [0 ]
187
+ )
188
+ op_func = getattr (self_obj , node .target )
189
+ elif node .op in ["get_attr" , "placeholder" , "output" ]:
190
+ op_name = node .op
191
+ except Exception :
192
+ pass
193
+ return op_name , op_func
194
+
195
+ def update_op_stats (self , op_stats , op_name , op_dtype ):
237
196
if op_name is not None :
238
- dtype_str = str (dtype ).replace ("torch." , "" )
197
+ dtype_str = str (op_dtype ).replace ("torch." , "" )
239
198
if op_stats .get (op_name , None ) is None :
240
199
op_stats [op_name ] = OpStat (op_name , {dtype_str : 1 }, 1 )
241
200
else :
242
201
op_stats [op_name ].op_dtypes [dtype_str ] = (
243
202
op_stats [op_name ].op_dtypes .get (dtype_str , 0 ) + 1
244
203
)
245
- op_stats [op_name ].count = op_stats [op_name ].count + 1
246
- return is_complete , op_stats
204
+ op_stats [op_name ].count += 1
205
+
206
+ def __call__ (self , gm : torch .fx .GraphModule , sample_inputs ):
207
+ # Use meta tensors as input to avoid actually running the model
208
+ meta_sample_inputs = convert_real_to_meta (sample_inputs )
209
+
210
+ op_stats = {}
211
+ num_ops_misses_dtypes = 0
212
+
213
+ input_idx = 0
214
+ node_outputs = {}
215
+ for node in gm .graph .nodes :
216
+ out = None
217
+ op_dtype = None
218
+ op_name , op_func = self .get_op_name_and_func (node , node_outputs )
219
+ if node .op == "placeholder" :
220
+ out = meta_sample_inputs [input_idx ]
221
+ input_idx += 1
222
+ elif node .op in ["call_function" , "call_module" , "call_method" ]:
223
+ try :
224
+ node_args = torch .fx .map_arg (
225
+ node .args ,
226
+ lambda n : node_outputs [n .name ]
227
+ if isinstance (n , torch .fx .Node )
228
+ else n ,
229
+ )
230
+ node_kwargs = torch .fx .map_arg (
231
+ node .kwargs ,
232
+ lambda n : node_outputs [n .name ]
233
+ if isinstance (n , torch .fx .Node )
234
+ else n ,
235
+ )
236
+ if node .op == "call_method" :
237
+ node_args = node_args [1 :]
238
+
239
+ if op_name == "_native_multi_head_attention" :
240
+ out = resolve_native_multi_head_attention (
241
+ * node_args , ** node_kwargs
242
+ )
243
+ elif op_name == "to" :
244
+ out = resolve_tensor_to (
245
+ node_outputs [node .args [0 ].name ], * node_args , ** node_kwargs
246
+ )
247
+ elif op_name == "item" :
248
+ out = resolve_tensor_item (node_outputs [node .args [0 ].name ])
249
+ else :
250
+ assert op_func is not None , f"op_func of { node } is None."
251
+ out = op_func (* node_args , ** node_kwargs )
252
+ except Exception :
253
+ out = resolve_with_real_tensor (
254
+ op_func , self .device , node_args , node_kwargs
255
+ )
256
+ if out is None :
257
+ if num_ops_misses_dtypes == 0 :
258
+ print (
259
+ f"dtype inference failed: node.op={ node .op } , op_name={ op_name } "
260
+ )
261
+ num_ops_misses_dtypes += 1
262
+ elif node .op == "get_attr" :
263
+ out = resolve_get_attr (traced , node )
264
+ elif node .op == "output" :
265
+ pass
266
+ else :
267
+ assert False , f"node.op: { node .op } "
268
+
269
+ if out is not None :
270
+ node_outputs [node .name ] = out
271
+ op_dtype = self .get_output_dtype (out )
272
+
273
+ if node .op not in ["placeholder" , "output" ]:
274
+ self .update_op_stats (op_stats , op_name , op_dtype )
275
+
276
+ if num_ops_misses_dtypes > 0 :
277
+ self .is_complete = False
278
+ self .num_ops_misses_dtypes += num_ops_misses_dtypes
279
+ num_ops = 0
280
+ for name , stat in op_stats .items ():
281
+ num_ops += stat .count
282
+ if name in self .op_stats .keys ():
283
+ self .op_stats [name ].update (stat )
284
+ else :
285
+ self .op_stats [name ] = stat
286
+ self .num_ops += num_ops
287
+ self .subgraph_counter += 1
288
+ return gm .forward
289
+
290
+ def summary (self ):
291
+ print (
292
+ f"Totally { self .subgraph_counter } subgraphs, { self .num_ops } operators, and { self .num_ops_misses_dtypes } operators failed to inference dtypes."
293
+ )
247
294
248
295
249
- def collect_op_stats_with_make_fx (model , input_dict , arg_types ):
296
def collect_op_stats_with_compile(model, sample_inputs, device):
    """Collect per-operator stats by compiling ``model`` with a meta backend.

    Args:
        model: the torch.nn.Module to analyze.
        sample_inputs: positional example inputs for one forward call.
        device: real device used as a fallback when meta execution fails.

    Returns:
        ``("compile", is_complete, op_stats)`` where ``op_stats`` maps op name
        to OpStat and ``is_complete`` is False if any dtype inference failed.

    Raises:
        TypeError: if ``model`` is not a torch.nn.Module.
    """
    # Raise instead of assert so the validation survives `python -O`.
    if not isinstance(model, torch.nn.Module):
        raise TypeError(f"{type(model)=}")
    meta_executor = GraphMetaExecutor(device)
    compiled_model = torch.compile(model, backend=meta_executor)
    compiled_model(*sample_inputs)
    meta_executor.summary()
    return "compile", meta_executor.is_complete, meta_executor.op_stats
303
+
304
+
305
def collect_op_stats_manual(model, sample_inputs, device):
    """Collect per-operator stats by FX-tracing ``model`` and interpreting
    the resulting graph on meta tensors.

    Returns:
        ``(is_complete, op_stats)``; ``(False, None)`` when symbolic tracing
        fails.
    """
    try:
        # FX symbolic trace
        traced = torch.fx.symbolic_trace(model)
    except Exception:
        print("Failed to FX symbolic_trace")
        return False, None

    executor = GraphMetaExecutor(device)
    executor(traced, sample_inputs)
    executor.summary()
    return executor.is_complete, executor.op_stats
318
+
319
+
320
+ def collect_op_stats_with_make_fx (model , sample_inputs ):
250
321
# Use meta tensors as input to avoid actually running the model
251
- meta_input_list = []
252
- for arg_name in arg_types .keys ():
253
- x = input_dict [arg_name ]
254
- meta_input_list .append (convert_real_to_meta (x ))
322
+ meta_input_list = convert_real_to_meta (sample_inputs )
255
323
256
324
try :
257
325
# Generate FX Graph, and automatically fill in meta information
@@ -325,14 +393,19 @@ def collect_model_stats(model_path, device, log_prompt):
325
393
model = model_class ()
326
394
arg_types = get_argument_types (model_class , "forward" )
327
395
input_dict = get_input_dict (model_path , device )
396
+ ordered_input_list = [input_dict [arg_name ] for arg_name in arg_types .keys ()]
328
397
329
398
num_ops = 0
330
399
num_outputs = 0
331
400
ops_count_dict = {}
332
401
op_dtypes = {}
333
- method , is_complete , op_stats = collect_op_stats (
334
- model , input_dict , arg_types , device
402
+ method , is_complete , op_stats = collect_op_stats_with_compile (
403
+ model , ordered_input_list , device
335
404
)
405
+
406
+ # method, is_complete, op_stats = collect_op_stats(
407
+ # model, input_dict, arg_types, device
408
+ # )
336
409
if op_stats is not None :
337
410
for op_name , stat in sorted (op_stats .items ()):
338
411
if op_name == "placeholder" :
@@ -474,5 +547,4 @@ def main(args):
474
547
help = "Log prompt for stats log filtering." ,
475
548
)
476
549
args = parser .parse_args ()
477
- print (f"[CollectStats Arguments] { args } " )
478
550
main (args = args )
0 commit comments