@@ -60,21 +60,26 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
60
60
val = gm
61
61
for a in attr_itr :
62
62
val = getattr (val , a )
63
- return val
63
+ out = val .to (device = "meta" ) if isinstance (val , torch .Tensor ) else val
64
+ return out
64
65
65
66
66
67
def collect_op_stats (model , input_dict ):
68
+ # FX symbolic trace
69
+ try :
70
+ traced = torch .fx .symbolic_trace (model )
71
+ # print(traced.graph)
72
+ except Exception :
73
+ print ("Failed to FX symbolic trace" )
74
+ return None
75
+
67
76
# Use meta tensors as input to avoid actually running the model
68
77
meta_input_dict = {}
69
78
for name , x in input_dict .items ():
70
79
meta_input_dict [name ] = (
71
80
torch .empty_like (x , device = "meta" ) if isinstance (x , torch .Tensor ) else x
72
81
)
73
82
74
- # FX symbolic trace
75
- traced = torch .fx .symbolic_trace (model )
76
- # print(traced.graph)
77
-
78
83
node_outputs = {}
79
84
op_stats = {}
80
85
for node in traced .graph .nodes :
@@ -84,7 +89,7 @@ def collect_op_stats(model, input_dict):
84
89
node_outputs [node .name ] = meta_input_dict [node .target ]
85
90
op_name = node .op
86
91
dtype = node_outputs [node .name ].dtype
87
- elif node .op in ["call_function" , "call_method " , "call_module " ]:
92
+ elif node .op in ["call_function" , "call_module " , "call_method " ]:
88
93
node_args = torch .fx .map_arg (
89
94
node .args ,
90
95
lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
@@ -96,28 +101,32 @@ def collect_op_stats(model, input_dict):
96
101
97
102
if node .op == "call_module" :
98
103
# classname of module
99
- submod = dict ( traced .named_modules ())[ node .target ]
104
+ submod = traced .get_submodule ( node .target )
100
105
op_name = submod .__class__ .__name__
101
- try :
102
- out = submod (* node_args , ** node_kwargs )
103
- node_outputs [node .name ] = out
104
- dtype = out .dtype if isinstance (out , torch .Tensor ) else None
105
- except Exception :
106
- node_outputs [node .name ] = None
107
- elif node .op in ["call_function" , "call_method" ]:
108
- op_name = (
109
- node .target .__name__ if node .op == "call_function" else node .target
106
+ op_func = submod
107
+ elif node .op == "call_function" :
108
+ op_name = node .target .__name__
109
+ op_func = node .target
110
+ elif node .op == "call_method" :
111
+ op_name = node .target
112
+ self_obj = (
113
+ node_outputs [node .args [0 ].name ]
114
+ if isinstance (node .args [0 ], torch .fx .Node )
115
+ else node .args [0 ]
110
116
)
111
- try :
112
- out = node .target (* node_args , ** node_kwargs )
113
- node_outputs [node .name ] = out
114
- dtype = out .dtype if isinstance (out , torch .Tensor ) else None
115
- except Exception :
116
- print (f"dtype inference failed: op_name={ op_name } " )
117
- node_outputs [node .name ] = None
117
+ op_func = getattr (self_obj , node .target )
118
+ node_args = node_args [1 :]
119
+
120
+ try :
121
+ out = op_func (* node_args , ** node_kwargs )
122
+ node_outputs [node .name ] = out
123
+ dtype = out .dtype if isinstance (out , torch .Tensor ) else None
124
+ except Exception :
125
+ print (f"dtype inference failed: node.op={ node .op } , op_name={ op_name } " )
126
+ node_outputs [node .name ] = None
118
127
elif node .op == "get_attr" :
119
- val = resolve_get_attr ( traced , node )
120
- out = val . to ( device = "meta" ) if isinstance ( val , torch . Tensor ) else val
128
+ op_name = node . op
129
+ out = resolve_get_attr ( traced , node )
121
130
node_outputs [node .name ] = out
122
131
dtype = out .dtype if isinstance (out , torch .Tensor ) else None
123
132
elif node .op == "output" :
@@ -156,18 +165,20 @@ def collect_model_stats(model_path, device, log_prompt):
156
165
num_outputs = 0
157
166
dtypes = set ()
158
167
op_stats = collect_op_stats (model , input_dict )
159
- for op_name , stat in op_stats .items ():
160
- if op_name == "placeholder" :
161
- num_inputs += stat .count
162
- elif op_name == "output" :
163
- num_outputs += stat .count
164
- else :
165
- num_ops += stat .count
166
- for v in stat .dtype :
167
- if v is not None :
168
- dtypes .add (v )
168
+ if op_stats is not None :
169
+ for op_name , stat in op_stats .items ():
170
+ if op_name == "placeholder" :
171
+ num_inputs += stat .count
172
+ elif op_name == "output" :
173
+ num_outputs += stat .count
174
+ else :
175
+ num_ops += stat .count
176
+ for v in stat .dtype :
177
+ if v is not None :
178
+ dtypes .add (v )
169
179
170
180
arg_types = get_argument_types (model_class , "forward" )
181
+ num_inputs = len (arg_types ) if op_stats is None else num_inputs
171
182
num_params = 0
172
183
param_dtypes = set ()
173
184
for name , arg_type in arg_types .items ():
0 commit comments