Skip to content

Commit bcf9d5a

Browse files
committed
Add support of get_attr and simplify some codes.
1 parent 7a91e19 commit bcf9d5a

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

graph_net/torch/collect_stats.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ class OpStat:
5555
count: int = 0
5656

5757

58+
def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
59+
attr_itr = node.target.split(".")
60+
val = gm
61+
for a in attr_itr:
62+
val = getattr(val, a)
63+
return val
64+
65+
5866
def collect_op_stats(model, input_dict):
5967
# Use meta tensors as input to avoid actually running the model
6068
meta_input_dict = {}
@@ -77,14 +85,14 @@ def collect_op_stats(model, input_dict):
7785
op_name = node.op
7886
dtype = node_outputs[node.name].dtype
7987
elif node.op in ["call_function", "call_method", "call_module"]:
80-
node_args = []
81-
for arg in node.args:
82-
node_args.append(
83-
node_outputs[arg.name] if hasattr(arg, "name") else arg
84-
)
85-
node_kwargs = {}
86-
for k, v in node.kwargs.items():
87-
node_kwargs[k] = node_outputs[v.name] if hasattr(v, "name") else v
88+
node_args = torch.fx.map_arg(
89+
node.args,
90+
lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
91+
)
92+
node_kwargs = torch.fx.map_arg(
93+
node.kwargs,
94+
lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
95+
)
8896

8997
if node.op == "call_module":
9098
# classname of module
@@ -107,13 +115,17 @@ def collect_op_stats(model, input_dict):
107115
except Exception:
108116
print(f"dtype inference failed: op_name={op_name}")
109117
node_outputs[node.name] = None
118+
elif node.op == "get_attr":
119+
val = resolve_get_attr(traced, node)
120+
out = val.to(device="meta") if isinstance(val, torch.Tensor) else val
121+
node_outputs[node.name] = out
122+
dtype = out.dtype if isinstance(out, torch.Tensor) else None
110123
elif node.op == "output":
111124
op_name = node.op
112-
node_args = []
113-
for arg in node.args:
114-
node_args.append(
115-
node_outputs[arg.name] if hasattr(arg, "name") else arg
116-
)
125+
node_args = torch.fx.map_arg(
126+
node.args,
127+
lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
128+
)
117129
node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args
118130
dtype = (
119131
node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None

0 commit comments

Comments
 (0)