@@ -55,6 +55,14 @@ class OpStat:
55
55
count : int = 0
56
56
57
57
58
+ def resolve_get_attr (gm : torch .fx .GraphModule , node : torch .fx .Node ):
59
+ attr_itr = node .target .split ("." )
60
+ val = gm
61
+ for a in attr_itr :
62
+ val = getattr (val , a )
63
+ return val
64
+
65
+
58
66
def collect_op_stats (model , input_dict ):
59
67
# Use meta tensors as input to avoid actually running the model
60
68
meta_input_dict = {}
@@ -77,14 +85,14 @@ def collect_op_stats(model, input_dict):
77
85
op_name = node .op
78
86
dtype = node_outputs [node .name ].dtype
79
87
elif node .op in ["call_function" , "call_method" , "call_module" ]:
80
- node_args = []
81
- for arg in node .args :
82
- node_args . append (
83
- node_outputs [ arg . name ] if hasattr ( arg , "name" ) else arg
84
- )
85
- node_kwargs = {}
86
- for k , v in node . kwargs . items ():
87
- node_kwargs [ k ] = node_outputs [ v . name ] if hasattr ( v , "name" ) else v
88
+ node_args = torch . fx . map_arg (
89
+ node .args ,
90
+ lambda n : node_outputs [ n . name ] if isinstance ( n , torch . fx . Node ) else n ,
91
+ )
92
+ node_kwargs = torch . fx . map_arg (
93
+ node . kwargs ,
94
+ lambda n : node_outputs [ n . name ] if isinstance ( n , torch . fx . Node ) else n ,
95
+ )
88
96
89
97
if node .op == "call_module" :
90
98
# classname of module
@@ -107,13 +115,17 @@ def collect_op_stats(model, input_dict):
107
115
except Exception :
108
116
print (f"dtype inference failed: op_name={ op_name } " )
109
117
node_outputs [node .name ] = None
118
+ elif node .op == "get_attr" :
119
+ val = resolve_get_attr (traced , node )
120
+ out = val .to (device = "meta" ) if isinstance (val , torch .Tensor ) else val
121
+ node_outputs [node .name ] = out
122
+ dtype = out .dtype if isinstance (out , torch .Tensor ) else None
110
123
elif node .op == "output" :
111
124
op_name = node .op
112
- node_args = []
113
- for arg in node .args :
114
- node_args .append (
115
- node_outputs [arg .name ] if hasattr (arg , "name" ) else arg
116
- )
125
+ node_args = torch .fx .map_arg (
126
+ node .args ,
127
+ lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
128
+ )
117
129
node_outputs [node .name ] = node_args [0 ] if len (node_args ) == 1 else node_args
118
130
dtype = (
119
131
node_args [0 ].dtype if isinstance (node_args [0 ], torch .Tensor ) else None
0 commit comments