Skip to content

Commit 1415926

Browse files
committed
Fix support of call_method.
1 parent bcf9d5a commit 1415926

File tree

1 file changed

+46
-35
lines changed

1 file changed

+46
-35
lines changed

graph_net/torch/collect_stats.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,26 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
6060
val = gm
6161
for a in attr_itr:
6262
val = getattr(val, a)
63-
return val
63+
out = val.to(device="meta") if isinstance(val, torch.Tensor) else val
64+
return out
6465

6566

6667
def collect_op_stats(model, input_dict):
68+
# FX symbolic trace
69+
try:
70+
traced = torch.fx.symbolic_trace(model)
71+
# print(traced.graph)
72+
except Exception:
73+
print("Failed to FX symbolic trace")
74+
return None
75+
6776
# Use meta tensors as input to avoid actually running the model
6877
meta_input_dict = {}
6978
for name, x in input_dict.items():
7079
meta_input_dict[name] = (
7180
torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
7281
)
7382

74-
# FX symbolic trace
75-
traced = torch.fx.symbolic_trace(model)
76-
# print(traced.graph)
77-
7883
node_outputs = {}
7984
op_stats = {}
8085
for node in traced.graph.nodes:
@@ -84,7 +89,7 @@ def collect_op_stats(model, input_dict):
8489
node_outputs[node.name] = meta_input_dict[node.target]
8590
op_name = node.op
8691
dtype = node_outputs[node.name].dtype
87-
elif node.op in ["call_function", "call_method", "call_module"]:
92+
elif node.op in ["call_function", "call_module", "call_method"]:
8893
node_args = torch.fx.map_arg(
8994
node.args,
9095
lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
@@ -96,28 +101,32 @@ def collect_op_stats(model, input_dict):
96101

97102
if node.op == "call_module":
98103
# classname of module
99-
submod = dict(traced.named_modules())[node.target]
104+
submod = traced.get_submodule(node.target)
100105
op_name = submod.__class__.__name__
101-
try:
102-
out = submod(*node_args, **node_kwargs)
103-
node_outputs[node.name] = out
104-
dtype = out.dtype if isinstance(out, torch.Tensor) else None
105-
except Exception:
106-
node_outputs[node.name] = None
107-
elif node.op in ["call_function", "call_method"]:
108-
op_name = (
109-
node.target.__name__ if node.op == "call_function" else node.target
106+
op_func = submod
107+
elif node.op == "call_function":
108+
op_name = node.target.__name__
109+
op_func = node.target
110+
elif node.op == "call_method":
111+
op_name = node.target
112+
self_obj = (
113+
node_outputs[node.args[0].name]
114+
if isinstance(node.args[0], torch.fx.Node)
115+
else node.args[0]
110116
)
111-
try:
112-
out = node.target(*node_args, **node_kwargs)
113-
node_outputs[node.name] = out
114-
dtype = out.dtype if isinstance(out, torch.Tensor) else None
115-
except Exception:
116-
print(f"dtype inference failed: op_name={op_name}")
117-
node_outputs[node.name] = None
117+
op_func = getattr(self_obj, node.target)
118+
node_args = node_args[1:]
119+
120+
try:
121+
out = op_func(*node_args, **node_kwargs)
122+
node_outputs[node.name] = out
123+
dtype = out.dtype if isinstance(out, torch.Tensor) else None
124+
except Exception:
125+
print(f"dtype inference failed: node.op={node.op}, op_name={op_name}")
126+
node_outputs[node.name] = None
118127
elif node.op == "get_attr":
119-
val = resolve_get_attr(traced, node)
120-
out = val.to(device="meta") if isinstance(val, torch.Tensor) else val
128+
op_name = node.op
129+
out = resolve_get_attr(traced, node)
121130
node_outputs[node.name] = out
122131
dtype = out.dtype if isinstance(out, torch.Tensor) else None
123132
elif node.op == "output":
@@ -156,18 +165,20 @@ def collect_model_stats(model_path, device, log_prompt):
156165
num_outputs = 0
157166
dtypes = set()
158167
op_stats = collect_op_stats(model, input_dict)
159-
for op_name, stat in op_stats.items():
160-
if op_name == "placeholder":
161-
num_inputs += stat.count
162-
elif op_name == "output":
163-
num_outputs += stat.count
164-
else:
165-
num_ops += stat.count
166-
for v in stat.dtype:
167-
if v is not None:
168-
dtypes.add(v)
168+
if op_stats is not None:
169+
for op_name, stat in op_stats.items():
170+
if op_name == "placeholder":
171+
num_inputs += stat.count
172+
elif op_name == "output":
173+
num_outputs += stat.count
174+
else:
175+
num_ops += stat.count
176+
for v in stat.dtype:
177+
if v is not None:
178+
dtypes.add(v)
169179

170180
arg_types = get_argument_types(model_class, "forward")
181+
num_inputs = len(arg_types) if op_stats is None else num_inputs
171182
num_params = 0
172183
param_dtypes = set()
173184
for name, arg_type in arg_types.items():

0 commit comments

Comments
 (0)