@@ -79,7 +79,7 @@ def resolve_native_multi_head_attention(*args, **kwargs):
79
79
80
80
81
81
def resolve_tensor_to (tensor , * args , ** kwargs ):
82
- if isinstance (args [0 ], torch .dtype ):
82
+ if len ( args ) > 0 and isinstance (args [0 ], torch .dtype ):
83
83
dtype = args [0 ]
84
84
else :
85
85
dtype = tensor .dtype
@@ -99,7 +99,40 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
99
99
return out
100
100
101
101
102
- def collect_op_stats_manual (model , input_dict ):
102
+ def convert_real_to_meta (x ):
103
+ if isinstance (x , torch .Tensor ) and not x .is_meta :
104
+ return torch .empty_like (x , device = "meta" )
105
+ elif isinstance (x , (list , tuple )):
106
+ return type (x )(convert_real_to_meta (v ) for v in x )
107
+ elif isinstance (x , dict ):
108
+ return {k : convert_real_to_meta (v ) for k , v in x .items ()}
109
+ else :
110
+ return x
111
+
112
+
113
+ def convert_meta_to_real (x , device ):
114
+ if isinstance (x , torch .Tensor ) and x .is_meta :
115
+ return torch .empty_like (x , device = device )
116
+ elif isinstance (x , (list , tuple )):
117
+ return type (x )(convert_meta_to_real (v , device ) for v in x )
118
+ elif isinstance (x , dict ):
119
+ return {k : convert_meta_to_real (v , device ) for k , v in x .items ()}
120
+ else :
121
+ return x
122
+
123
+
124
+ def resolve_with_real_tensor (op_func , device , meta_args , meta_kwargs ):
125
+ try :
126
+ real_args = convert_meta_to_real (meta_args , device )
127
+ real_kwargs = convert_meta_to_real (meta_kwargs , device )
128
+
129
+ real_out = op_func (* real_args , ** real_kwargs )
130
+ return convert_real_to_meta (real_out )
131
+ except Exception :
132
+ return None
133
+
134
+
135
+ def collect_op_stats_manual (model , input_dict , device ):
103
136
try :
104
137
# FX symbolic trace
105
138
traced = torch .fx .symbolic_trace (model )
@@ -109,11 +142,19 @@ def collect_op_stats_manual(model, input_dict):
109
142
return False , None
110
143
111
144
# Use meta tensors as input to avoid actually running the model
112
- meta_input_dict = {}
113
- for name , x in input_dict .items ():
114
- meta_input_dict [name ] = (
115
- torch .empty_like (x , device = "meta" ) if isinstance (x , torch .Tensor ) else x
116
- )
145
+ meta_input_dict = convert_real_to_meta (input_dict )
146
+
147
+ def get_output_dtype (out ):
148
+ if isinstance (out , torch .Tensor ):
149
+ return out .dtype
150
+ if (
151
+ isinstance (out , (list , tuple ))
152
+ and len (out ) > 0
153
+ and isinstance (out [0 ], torch .Tensor )
154
+ ):
155
+ return out [0 ].dtype
156
+ else :
157
+ return None
117
158
118
159
is_complete = True
119
160
op_stats = {}
@@ -157,6 +198,7 @@ def collect_op_stats_manual(model, input_dict):
157
198
if op_name == "_native_multi_head_attention" :
158
199
out = resolve_native_multi_head_attention (* node_args , ** node_kwargs )
159
200
elif op_name == "to" :
201
+ # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
160
202
out = resolve_tensor_to (
161
203
node_outputs [node .args [0 ].name ], * node_args , ** node_kwargs
162
204
)
@@ -165,26 +207,30 @@ def collect_op_stats_manual(model, input_dict):
165
207
else :
166
208
out = op_func (* node_args , ** node_kwargs )
167
209
node_outputs [node .name ] = out
168
- dtype = out . dtype if isinstance (out , torch . Tensor ) else None
210
+ dtype = get_output_dtype (out )
169
211
except Exception :
170
- print (f"dtype inference failed: node.op={ node .op } , op_name={ op_name } " )
171
- node_outputs [node .name ] = None
172
- is_complete = False
212
+ out = resolve_with_real_tensor (op_func , device , node_args , node_kwargs )
213
+ node_outputs [node .name ] = out
214
+ if out is not None :
215
+ dtype = get_output_dtype (out )
216
+ else :
217
+ print (
218
+ f"dtype inference failed: node.op={ node .op } , op_name={ op_name } "
219
+ )
220
+ is_complete = False
173
221
elif node .op == "get_attr" :
174
222
op_name = node .op
175
223
out = resolve_get_attr (traced , node )
176
224
node_outputs [node .name ] = out
177
- dtype = out . dtype if isinstance (out , torch . Tensor ) else None
225
+ dtype = get_output_dtype (out )
178
226
elif node .op == "output" :
179
227
op_name = node .op
180
228
node_args = torch .fx .map_arg (
181
229
node .args ,
182
230
lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
183
231
)
184
232
node_outputs [node .name ] = node_args [0 ] if len (node_args ) == 1 else node_args
185
- dtype = (
186
- node_args [0 ].dtype if isinstance (node_args [0 ], torch .Tensor ) else None
187
- )
233
+ dtype = get_output_dtype (node_args [0 ])
188
234
else :
189
235
assert False , f"node.op: { node .op } "
190
236
@@ -205,10 +251,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
205
251
meta_input_list = []
206
252
for arg_name in arg_types .keys ():
207
253
x = input_dict [arg_name ]
208
- meta_x = (
209
- torch .empty_like (x , device = "meta" ) if isinstance (x , torch .Tensor ) else x
210
- )
211
- meta_input_list .append (meta_x )
254
+ meta_input_list .append (convert_real_to_meta (x ))
212
255
213
256
try :
214
257
# Generate FX Graph, and automatically fill in meta information
@@ -262,8 +305,10 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
262
305
return is_complete , op_stats
263
306
264
307
265
- def collect_op_stats (model , input_dict , arg_types ):
266
- is_complete_manual , op_stats_manual = collect_op_stats_manual (model , input_dict )
308
+ def collect_op_stats (model , input_dict , arg_types , device ):
309
+ is_complete_manual , op_stats_manual = collect_op_stats_manual (
310
+ model , input_dict , device
311
+ )
267
312
if not is_complete_manual :
268
313
is_complete_make_fx , op_stats_make_fx = collect_op_stats_with_make_fx (
269
314
model , input_dict , arg_types
@@ -285,7 +330,9 @@ def collect_model_stats(model_path, device, log_prompt):
285
330
num_outputs = 0
286
331
ops_count_dict = {}
287
332
op_dtypes = {}
288
- method , is_complete , op_stats = collect_op_stats (model , input_dict , arg_types )
333
+ method , is_complete , op_stats = collect_op_stats (
334
+ model , input_dict , arg_types , device
335
+ )
289
336
if op_stats is not None :
290
337
for op_name , stat in sorted (op_stats .items ()):
291
338
if op_name == "placeholder" :
0 commit comments