Skip to content

Commit 7a91e19

Browse files
committed
Implement a function to collect the model's execution stats.
1 parent c1bc381 commit 7a91e19

File tree

2 files changed

+229
-1
lines changed

2 files changed

+229
-1
lines changed

graph_net/torch/collect_stats.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import argparse
2+
import os
3+
import sys
4+
import math
5+
import importlib
6+
import inspect
7+
from typing import Type
8+
from dataclasses import dataclass, field
9+
from collections import defaultdict
10+
11+
import torch
12+
from torch.fx.passes.shape_prop import ShapeProp
13+
from graph_net.torch import utils
14+
15+
16+
def is_single_model_dir(model_dir):
17+
return os.path.isfile(f"{model_dir}/graph_net.json")
18+
19+
20+
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
21+
spec = importlib.util.spec_from_file_location("unnamed", file_path)
22+
unnamed = importlib.util.module_from_spec(spec)
23+
spec.loader.exec_module(unnamed)
24+
model_class = getattr(unnamed, class_name, None)
25+
return model_class
26+
27+
28+
def get_argument_types(model_class, func_name):
29+
arg_types = {}
30+
for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction):
31+
if name == func_name:
32+
for arg_name, arg in inspect.signature(func).parameters.items():
33+
if arg_name != "self":
34+
arg_types[arg_name] = (
35+
None if arg.annotation is inspect._empty else arg.annotation
36+
)
37+
return arg_types
38+
39+
40+
def get_input_dict(model_path, device):
41+
inputs_params = utils.load_converted_from_text(f"{model_path}")
42+
params = inputs_params["weight_info"]
43+
for tensor_meta in params.values():
44+
if hasattr(tensor_meta, "device"):
45+
tensor_meta.device = device
46+
return {
47+
k: utils.replay_tensor(v).to(torch.device(device)) for k, v in params.items()
48+
}
49+
50+
51+
@dataclass
52+
class OpStat:
53+
op_name: str
54+
dtype: set[str] = field(default_factory=set)
55+
count: int = 0
56+
57+
58+
def collect_op_stats(model, input_dict):
59+
# Use meta tensors as input to avoid actually running the model
60+
meta_input_dict = {}
61+
for name, x in input_dict.items():
62+
meta_input_dict[name] = (
63+
torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
64+
)
65+
66+
# FX symbolic trace
67+
traced = torch.fx.symbolic_trace(model)
68+
# print(traced.graph)
69+
70+
node_outputs = {}
71+
op_stats = {}
72+
for node in traced.graph.nodes:
73+
op_name = None
74+
dtype = None
75+
if node.op == "placeholder":
76+
node_outputs[node.name] = meta_input_dict[node.target]
77+
op_name = node.op
78+
dtype = node_outputs[node.name].dtype
79+
elif node.op in ["call_function", "call_method", "call_module"]:
80+
node_args = []
81+
for arg in node.args:
82+
node_args.append(
83+
node_outputs[arg.name] if hasattr(arg, "name") else arg
84+
)
85+
node_kwargs = {}
86+
for k, v in node.kwargs.items():
87+
node_kwargs[k] = node_outputs[v.name] if hasattr(v, "name") else v
88+
89+
if node.op == "call_module":
90+
# classname of module
91+
submod = dict(traced.named_modules())[node.target]
92+
op_name = submod.__class__.__name__
93+
try:
94+
out = submod(*node_args, **node_kwargs)
95+
node_outputs[node.name] = out
96+
dtype = out.dtype if isinstance(out, torch.Tensor) else None
97+
except Exception:
98+
node_outputs[node.name] = None
99+
elif node.op in ["call_function", "call_method"]:
100+
op_name = (
101+
node.target.__name__ if node.op == "call_function" else node.target
102+
)
103+
try:
104+
out = node.target(*node_args, **node_kwargs)
105+
node_outputs[node.name] = out
106+
dtype = out.dtype if isinstance(out, torch.Tensor) else None
107+
except Exception:
108+
print(f"dtype inference failed: op_name={op_name}")
109+
node_outputs[node.name] = None
110+
elif node.op == "output":
111+
op_name = node.op
112+
node_args = []
113+
for arg in node.args:
114+
node_args.append(
115+
node_outputs[arg.name] if hasattr(arg, "name") else arg
116+
)
117+
node_outputs[node.name] = node_args[0] if len(node_args) == 1 else node_args
118+
dtype = (
119+
node_args[0].dtype if isinstance(node_args[0], torch.Tensor) else None
120+
)
121+
else:
122+
assert False, f"node.op: {node.op}"
123+
124+
if op_name is not None:
125+
dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None
126+
if op_stats.get(op_name, None) is None:
127+
op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
128+
else:
129+
op_stats[op_name].dtype.add(dtype_str)
130+
op_stats[op_name].count = op_stats[op_name].count + 1
131+
return op_stats
132+
133+
134+
def collect_model_stats(model_path, device, log_prompt):
135+
print(f"Collect information for {model_path}")
136+
model_class = load_class_from_file(
137+
os.path.join(model_path, "model.py"), "GraphModule"
138+
)
139+
model = model_class()
140+
input_dict = get_input_dict(model_path, device)
141+
142+
num_ops = 0
143+
num_inputs = 0
144+
num_outputs = 0
145+
dtypes = set()
146+
op_stats = collect_op_stats(model, input_dict)
147+
for op_name, stat in op_stats.items():
148+
if op_name == "placeholder":
149+
num_inputs += stat.count
150+
elif op_name == "output":
151+
num_outputs += stat.count
152+
else:
153+
num_ops += stat.count
154+
for v in stat.dtype:
155+
if v is not None:
156+
dtypes.add(v)
157+
158+
arg_types = get_argument_types(model_class, "forward")
159+
num_params = 0
160+
param_dtypes = set()
161+
for name, arg_type in arg_types.items():
162+
if arg_type == torch.nn.parameter.Parameter:
163+
count = math.prod(input_dict[name].shape)
164+
# print(f"Parameter {name}: {count}")
165+
num_params += count
166+
param_dtypes.add(str(input_dict[name].dtype).replace("torch.", ""))
167+
num_params_in_billion = num_params / 1e9
168+
169+
dtypes_str = "[" + ",".join(dtypes) + "]"
170+
param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
171+
print(
172+
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}",
173+
file=sys.stderr,
174+
flush=True,
175+
)
176+
177+
178+
def main(args):
179+
if args.model_path is not None:
180+
assert os.path.isdir(args.model_path)
181+
assert is_single_model_dir(args.model_path)
182+
collect_model_stats(args.model_path, args.device, args.log_prompt)
183+
else:
184+
graph_net_samples_path = (
185+
(graph_net.torch.samples_util.get_default_samples_directory())
186+
if args.graph_net_samples_path is None
187+
else args.graph_net_samples_path
188+
)
189+
for root, dirs, files in os.walk(graph_net_samples_path):
190+
if is_single_model_dir(root):
191+
collect_model_stats(root, args.device, args.log_prompt)
192+
193+
194+
if __name__ == "__main__":
195+
parser = argparse.ArgumentParser(
196+
description="Validate a computation graph sample. return 0 if success"
197+
)
198+
parser.add_argument(
199+
"--device",
200+
type=str,
201+
required=False,
202+
default="cuda",
203+
help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
204+
)
205+
parser.add_argument(
206+
"--model-path",
207+
type=str,
208+
required=False,
209+
default=None,
210+
help="Computation graph sample directory. e.g '../../samples/torch/resnet18'",
211+
)
212+
parser.add_argument(
213+
"--graph-net-samples-path",
214+
type=str,
215+
required=False,
216+
default=None,
217+
help="GraphNet samples directory. e.g '../../samples'",
218+
)
219+
parser.add_argument(
220+
"--log-prompt",
221+
type=str,
222+
required=False,
223+
default="graph-net-collect-stats-log",
224+
help="Log prompt for stats log filtering.",
225+
)
226+
args = parser.parse_args()
227+
main(args=args)

graph_net/torch/test_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from . import utils
21
import argparse
32
import importlib.util
43
import inspect
@@ -14,6 +13,8 @@
1413
import json
1514
import numpy as np
1615
import platform
16+
17+
from graph_net.torch import utils
1718
from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
1819
from graph_net.torch.backend.tvm_backend import TvmBackend
1920
from graph_net.torch.backend.xla_backend import XlaBackend

0 commit comments

Comments
 (0)