@@ -79,7 +79,7 @@ def resolve_native_multi_head_attention(*args, **kwargs):
7979
8080
8181def resolve_tensor_to (tensor , * args , ** kwargs ):
82- if isinstance (args [0 ], torch .dtype ):
82+ if len ( args ) > 0 and isinstance (args [0 ], torch .dtype ):
8383 dtype = args [0 ]
8484 else :
8585 dtype = tensor .dtype
@@ -99,7 +99,40 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
9999 return out
100100
101101
102- def collect_op_stats_manual (model , input_dict ):
def convert_real_to_meta(x):
    """Recursively replace real tensors inside ``x`` with meta tensors.

    Args:
        x: A tensor, a (possibly nested) list/tuple/dict of values, or any
            other object.

    Returns:
        A structure shaped like ``x`` in which every non-meta ``torch.Tensor``
        has been replaced by ``torch.empty_like(..., device="meta")``.
        Lists and tuples keep their concrete type (namedtuples included),
        dicts are rebuilt as plain ``dict`` objects, and every other value
        (including tensors that are already meta) is returned unchanged.
    """
    if isinstance(x, torch.Tensor) and not x.is_meta:
        return torch.empty_like(x, device="meta")
    if isinstance(x, (list, tuple)):
        items = [convert_real_to_meta(v) for v in x]
        # namedtuple constructors take one positional argument per field,
        # so passing a single iterable would raise TypeError.
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            return type(x)(*items)
        return type(x)(items)
    if isinstance(x, dict):
        return {k: convert_real_to_meta(v) for k, v in x.items()}
    return x
111+
112+
def convert_meta_to_real(x, device):
    """Recursively replace meta tensors inside ``x`` with real tensors.

    Args:
        x: A tensor, a (possibly nested) list/tuple/dict of values, or any
            other object.
        device: Device on which the real tensors are allocated.

    Returns:
        A structure shaped like ``x`` in which every meta ``torch.Tensor``
        has been replaced by a zero-filled tensor of the same shape and
        dtype on ``device``. Lists and tuples keep their concrete type
        (namedtuples included), dicts are rebuilt as plain ``dict`` objects,
        and every other value (including non-meta tensors) is returned
        unchanged.
    """
    if isinstance(x, torch.Tensor) and x.is_meta:
        # empty_like leaves the memory uninitialised; zero it so that the
        # values are deterministic (e.g. integer tensors later used as
        # indices do not contain arbitrary garbage).
        return torch.empty_like(x, device=device).zero_()
    if isinstance(x, (list, tuple)):
        items = [convert_meta_to_real(v, device) for v in x]
        # namedtuple constructors take one positional argument per field,
        # so passing a single iterable would raise TypeError.
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            return type(x)(*items)
        return type(x)(items)
    if isinstance(x, dict):
        return {k: convert_meta_to_real(v, device) for k, v in x.items()}
    return x
122+
123+
def resolve_with_real_tensor(op_func, device, meta_args, meta_kwargs):
    """Best-effort evaluation of ``op_func`` on real tensors.

    The positional and keyword arguments are passed through
    ``convert_meta_to_real`` (targeting ``device``), ``op_func`` is called
    with the converted values, and its result is passed through
    ``convert_real_to_meta`` before being returned.

    Returns:
        The converted result, or ``None`` if any step raises ``Exception``.
    """
    try:
        materialized_args = convert_meta_to_real(meta_args, device)
        materialized_kwargs = convert_meta_to_real(meta_kwargs, device)
        result = op_func(*materialized_args, **materialized_kwargs)
        return convert_real_to_meta(result)
    except Exception:
        # Deliberate best-effort: the caller treats None as "could not infer".
        return None
133+
134+
135+ def collect_op_stats_manual (model , input_dict , device ):
103136 try :
104137 # FX symbolic trace
105138 traced = torch .fx .symbolic_trace (model )
@@ -109,11 +142,19 @@ def collect_op_stats_manual(model, input_dict):
109142 return False , None
110143
111144 # Use meta tensors as input to avoid actually running the model
112- meta_input_dict = {}
113- for name , x in input_dict .items ():
114- meta_input_dict [name ] = (
115- torch .empty_like (x , device = "meta" ) if isinstance (x , torch .Tensor ) else x
116- )
145+ meta_input_dict = convert_real_to_meta (input_dict )
146+
147+ def get_output_dtype (out ):
148+ if isinstance (out , torch .Tensor ):
149+ return out .dtype
150+ if (
151+ isinstance (out , (list , tuple ))
152+ and len (out ) > 0
153+ and isinstance (out [0 ], torch .Tensor )
154+ ):
155+ return out [0 ].dtype
156+ else :
157+ return None
117158
118159 is_complete = True
119160 op_stats = {}
@@ -157,6 +198,7 @@ def collect_op_stats_manual(model, input_dict):
157198 if op_name == "_native_multi_head_attention" :
158199 out = resolve_native_multi_head_attention (* node_args , ** node_kwargs )
159200 elif op_name == "to" :
201+ # print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
160202 out = resolve_tensor_to (
161203 node_outputs [node .args [0 ].name ], * node_args , ** node_kwargs
162204 )
@@ -165,26 +207,30 @@ def collect_op_stats_manual(model, input_dict):
165207 else :
166208 out = op_func (* node_args , ** node_kwargs )
167209 node_outputs [node .name ] = out
168- dtype = out . dtype if isinstance (out , torch . Tensor ) else None
210+ dtype = get_output_dtype (out )
169211 except Exception :
170- print (f"dtype inference failed: node.op={ node .op } , op_name={ op_name } " )
171- node_outputs [node .name ] = None
172- is_complete = False
212+ out = resolve_with_real_tensor (op_func , device , node_args , node_kwargs )
213+ node_outputs [node .name ] = out
214+ if out is not None :
215+ dtype = get_output_dtype (out )
216+ else :
217+ print (
218+ f"dtype inference failed: node.op={ node .op } , op_name={ op_name } "
219+ )
220+ is_complete = False
173221 elif node .op == "get_attr" :
174222 op_name = node .op
175223 out = resolve_get_attr (traced , node )
176224 node_outputs [node .name ] = out
177- dtype = out . dtype if isinstance (out , torch . Tensor ) else None
225+ dtype = get_output_dtype (out )
178226 elif node .op == "output" :
179227 op_name = node .op
180228 node_args = torch .fx .map_arg (
181229 node .args ,
182230 lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
183231 )
184232 node_outputs [node .name ] = node_args [0 ] if len (node_args ) == 1 else node_args
185- dtype = (
186- node_args [0 ].dtype if isinstance (node_args [0 ], torch .Tensor ) else None
187- )
233+ dtype = get_output_dtype (node_args [0 ])
188234 else :
189235 assert False , f"node.op: { node .op } "
190236
@@ -205,10 +251,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
205251 meta_input_list = []
206252 for arg_name in arg_types .keys ():
207253 x = input_dict [arg_name ]
208- meta_x = (
209- torch .empty_like (x , device = "meta" ) if isinstance (x , torch .Tensor ) else x
210- )
211- meta_input_list .append (meta_x )
254+ meta_input_list .append (convert_real_to_meta (x ))
212255
213256 try :
214257 # Generate FX Graph, and automatically fill in meta information
@@ -262,8 +305,10 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
262305 return is_complete , op_stats
263306
264307
265- def collect_op_stats (model , input_dict , arg_types ):
266- is_complete_manual , op_stats_manual = collect_op_stats_manual (model , input_dict )
308+ def collect_op_stats (model , input_dict , arg_types , device ):
309+ is_complete_manual , op_stats_manual = collect_op_stats_manual (
310+ model , input_dict , device
311+ )
267312 if not is_complete_manual :
268313 is_complete_make_fx , op_stats_make_fx = collect_op_stats_with_make_fx (
269314 model , input_dict , arg_types
@@ -285,7 +330,9 @@ def collect_model_stats(model_path, device, log_prompt):
285330 num_outputs = 0
286331 ops_count_dict = {}
287332 op_dtypes = {}
288- method , is_complete , op_stats = collect_op_stats (model , input_dict , arg_types )
333+ method , is_complete , op_stats = collect_op_stats (
334+ model , input_dict , arg_types , device
335+ )
289336 if op_stats is not None :
290337 for op_name , stat in sorted (op_stats .items ()):
291338 if op_name == "placeholder" :
0 commit comments