Skip to content

Commit 110c7a9

Browse files
committed
Merge branch 'collect_info' into add_cv_samples_5_need_fix
2 parents 593d460 + 5db7117 commit 110c7a9

File tree

5 files changed

+853
-3
lines changed

5 files changed

+853
-3
lines changed

graph_net/collect_stats_util.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import ast
2+
import importlib
3+
import inspect
4+
from dataclasses import dataclass, field
5+
from typing import Dict
6+
7+
8+
@dataclass
class OpStat:
    """Usage record for a single operator name.

    ``count`` is the number of recorded occurrences and ``op_dtypes`` maps a
    dtype name to how many of those occurrences carried that dtype.
    """

    op_name: str
    op_dtypes: dict[str, int] = field(default_factory=dict)
    count: int = 0

    def update(self, other):
        """Fold ``other`` into this record in place.

        The call is a no-op unless ``other`` is an ``OpStat`` describing the
        same ``op_name``.
        """
        if not isinstance(other, OpStat):
            return
        if self.op_name != other.op_name:
            return

        self.count += other.count
        for dtype_name, occurrences in other.op_dtypes.items():
            previous = self.op_dtypes.get(dtype_name, 0)
            self.op_dtypes[dtype_name] = previous + occurrences
19+
20+
21+
@dataclass
class ModelStats:
    """Flat record of the statistics collected for one model sample.

    Only ``model_path`` is required. Scalar fields default to ``None`` and
    the histogram fields default to a fresh empty dict per instance.
    """

    # Identifies the sample this record belongs to.
    model_path: str

    # Scalar summary values.
    num_inputs: int = None
    num_params: int = None
    num_outputs: int = None
    num_ops: int = None
    model_size_in_billion: float = None

    # Histograms keyed by name, each value being an occurrence count.
    input_dtypes: Dict[str, int] = field(default_factory=dict)
    param_dtypes: Dict[str, int] = field(default_factory=dict)
    op_dtypes: Dict[str, int] = field(default_factory=dict)
    ops: Dict[str, int] = field(default_factory=dict)

    # Provenance labels.
    source: str = None
    heuristic_tag: str = None
35+
36+
37+
def print_model_stats(stats, log_prompt):
    """Print every field of ``stats`` as its own tagged, flushed log line.

    Each line has the form
    ``<log_prompt> [ModelStats.<key>] model_path:<path> <value>``; dict fields
    are rendered as space-separated ``key:value`` pairs.
    """
    assert isinstance(stats, ModelStats), f"{type(stats)=}"

    def render(counts):
        return " ".join(f"{name}:{amount}" for name, amount in counts.items())

    def rows():
        # A generator keeps each value lazily evaluated right before its line
        # is printed, one field at a time.
        yield "num_inputs", stats.num_inputs
        yield "num_params", stats.num_params
        yield "num_outputs", stats.num_outputs
        yield "num_ops", stats.num_ops
        yield "model_size", f"{stats.model_size_in_billion}B"
        yield "input_dtypes", render(stats.input_dtypes)
        yield "param_dtypes", render(stats.param_dtypes)
        yield "op_dtypes", render(stats.op_dtypes)
        yield "ops", render(stats.ops)
        yield "source", stats.source
        yield "heuristic_tag", stats.heuristic_tag

    for key, value in rows():
        print(
            f"{log_prompt} [ModelStats.{key}] model_path:{stats.model_path} {value}",
            flush=True,
        )
61+
62+
63+
def load_class_from_file(file_path, class_name):
    """Execute the Python file at ``file_path`` and return one of its attributes.

    Args:
        file_path: Path of the Python source file to execute as a module.
        class_name: Name of the attribute (normally a class) to fetch.

    Returns:
        The attribute named ``class_name``, or ``None`` when the executed
        module does not define it.

    Raises:
        ImportError: If no import spec/loader can be built for ``file_path``.
        Exception: Whatever the executed file itself raises.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {file_path!r}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name, None)
69+
70+
71+
def get_argument_name_and_types(model_class, func_name):
    """Map the parameter names of ``model_class.func_name`` to their annotations.

    Only members that ``inspect.isfunction`` accepts are considered. The
    ``self`` parameter is skipped and unannotated parameters map to ``None``.

    Args:
        model_class: Class to inspect.
        func_name: Name of the function member to look up.

    Returns:
        A dict in signature order; empty when no such function exists.
    """
    functions = dict(inspect.getmembers(model_class, predicate=inspect.isfunction))
    func = functions.get(func_name)
    if func is None:
        return {}

    # Public spelling of the "no annotation" sentinel (was ``inspect._empty``).
    no_annotation = inspect.Parameter.empty
    return {
        arg_name: (None if param.annotation is no_annotation else param.annotation)
        for arg_name, param in inspect.signature(func).parameters.items()
        if arg_name != "self"
    }
81+
82+
83+
def _iter_own_returns(func_node):
    """Yield ``return`` statements of ``func_node``, skipping nested scopes.

    Traversal is breadth-first, but nested function, lambda and class
    definitions are not entered, so their ``return`` statements are never
    attributed to ``func_node``.
    """
    from collections import deque

    nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef)
    pending = deque(ast.iter_child_nodes(func_node))
    while pending:
        node = pending.popleft()
        if isinstance(node, nested_scopes):
            continue
        if isinstance(node, ast.Return):
            yield node
        pending.extend(ast.iter_child_nodes(node))


def get_number_of_returns(file_path, class_name, func_name):
    """Statically count the values returned by ``class_name.func_name``.

    The file is parsed, not executed. The first ``return`` statement found in
    the function's own body decides the result: a bare ``return`` counts as 0,
    a tuple expression counts its elements, anything else counts as 1.

    Args:
        file_path: Path of the Python source file to parse.
        class_name: Name of a top-level class in that file.
        func_name: Name of a method defined directly in that class.

    Returns:
        The number of returned values, or 0 when the class, the method or a
        ``return`` statement cannot be found.

    Raises:
        OSError: If the file cannot be read.
        SyntaxError: If the file is not valid Python.
    """
    # Read bytes so the parser honours a BOM / coding cookie instead of the
    # platform's locale encoding.
    with open(file_path, "rb") as source_file:
        tree = ast.parse(source_file.read())

    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    for node in tree.body:
        if not (isinstance(node, ast.ClassDef) and node.name == class_name):
            continue
        for member in node.body:
            if not (isinstance(member, function_types) and member.name == func_name):
                continue
            for stmt in _iter_own_returns(member):
                if stmt.value is None:
                    return 0
                if isinstance(stmt.value, ast.Tuple):
                    return len(stmt.value.elts)
                return 1
    return 0

graph_net/paddle/collect_stats.py

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
import argparse
2+
import os
3+
import re
4+
import sys
5+
import math
6+
import subprocess
7+
from datetime import datetime
8+
9+
import paddle
10+
from graph_net import collect_stats_util
11+
from graph_net.paddle import utils
12+
13+
14+
def is_single_model_dir(model_dir):
    """Return True when ``model_dir`` directly contains a ``graph_net.json`` file."""
    marker_file = f"{model_dir}/graph_net.json"
    return os.path.isfile(marker_file)
16+
17+
18+
def read_graph_source_and_tag(model_path):
    """Return the ``(source, heuristic_tag)`` pair describing a sample.

    The values come from ``<model_path>/graph_net.json`` when that file can be
    read and holds both keys. Otherwise they are guessed from well-known
    directory names found in ``model_path``.

    Args:
        model_path: Path string of the sample directory.

    Returns:
        A ``(source, heuristic_tag)`` tuple; ``("unknown", "unknown")`` when
        neither the metadata file nor the path gives a hint.
    """
    # Imported locally: the module's import block does not provide ``json``,
    # so the lookup used to fail and always fall through to the path guess.
    import json

    try:
        with open(os.path.join(model_path, "graph_net.json"), "r") as f:
            data = json.load(f)
        return data["source"], data["heuristic_tag"]
    except (OSError, ValueError, KeyError, TypeError):
        # Missing/unreadable file, malformed JSON, or missing keys: fall back
        # to the path-based heuristic below (best effort by design).
        pass

    path_hints = (
        ("PaddleX", "computer_vision"),
        ("PaddleNLP", "nlp"),
        ("PaddleScience", "scientific_computing"),
    )
    for source, heuristic_tag in path_hints:
        if source in model_path:
            return source, heuristic_tag
    return "unknown", "unknown"
32+
33+
34+
def get_input_spec(model_path):
    """Build one ``paddle.static.InputSpec`` per entry recorded for the sample.

    Entries come from ``utils.load_converted_list_from_text``; each one is
    expected to expose ``entry["info"]["dtype"]`` and ``entry["info"]["shape"]``.
    The returned list keeps the order of the loaded entries.
    """
    recorded_entries = utils.load_converted_list_from_text(f"{model_path}")
    specs = []
    for entry in recorded_entries:
        dtype = entry["info"]["dtype"]
        shape = entry["info"]["shape"]
        specs.append(paddle.static.InputSpec(shape, dtype))
    return specs
42+
43+
44+
class ProgramAnalyzer:
    """Collect per-operator and per-input statistics from a Paddle PIR program.

    Calling an instance on a program fills:

    * ``op_stats``: op name -> ``collect_stats_util.OpStat``
    * ``input_dict``: ``pd_op.data`` name -> ``{"dtype": ..., "shape": ...}``
    * ``num_ops``: number of ``pd_op.*`` ops other than ``pd_op.data``
    * ``num_ops_misses_dtypes``: ops whose result dtype could not be determined
    * ``is_complete``: True when every counted op got a dtype

    All of these are reset at the start of every call, so an instance can be
    reused for several programs.
    """

    # Extracts the payload of every ``tensor<...>`` in a printed PIR type.
    _TENSOR_TYPE_PATTERN = re.compile(r"tensor<([^>]+)>")

    # Short dtype spellings understood when parsing printed tensor types.
    _SHORT_FORM_TO_DTYPE = {
        "f32": "float32",
        "f16": "float16",
        "bf16": "bfloat16",
        "i64": "int64",
    }

    # Ops whose first result is not a dense tensor; their dtype is taken from
    # the first tensor named in the printed result type.
    _VECTOR_RESULT_OPS = (
        "split",
        "split_with_num",
        "meshgrid",
        "distribute_fpn_proposals",
    )

    def __init__(self):
        self.op_stats = {}
        self.input_dict = {}
        self.num_ops = 0
        self.num_ops_misses_dtypes = 0
        self.is_complete = True

    def update_op_stats(self, op_name, op_dtype):
        """Record one occurrence of ``op_name`` with result dtype ``op_dtype``.

        A ``None`` op name is ignored. The dtype is stored as its string form
        with any ``"paddle."`` substring removed.
        """
        if op_name is None:
            return
        dtype_str = str(op_dtype).replace("paddle.", "")
        stat = self.op_stats.get(op_name)
        if stat is None:
            self.op_stats[op_name] = collect_stats_util.OpStat(
                op_name, {dtype_str: 1}, 1
            )
            return
        stat.op_dtypes[dtype_str] = stat.op_dtypes.get(dtype_str, 0) + 1
        stat.count += 1

    def parse_pir_value_dtypes(self, type_str):
        """Return the dtype names of every ``tensor<...>`` found in ``type_str``.

        Example input: ``"vec[tensor<1x18x13x9xf32>,tensor<1x9x13x9xf32>]"``,
        which yields ``["float32", "float32"]``. The dtype is the text after
        the last ``x`` inside the angle brackets, matched case-insensitively.

        Raises:
            KeyError: If a tensor uses a short dtype that is not in the table.
        """
        dtype_strs = []
        for tensor_body in self._TENSOR_TYPE_PATTERN.findall(type_str):
            short_form = tensor_body.rsplit("x", 1)[-1].lower()
            dtype_strs.append(self._SHORT_FORM_TO_DTYPE[short_form])
        return dtype_strs

    def _record_input(self, op):
        """Store the dtype and shape attributes of a ``pd_op.data`` op."""
        op_attrs = op.attrs()
        op_dtype = op_attrs["dtype"]
        self.input_dict[op_attrs["name"]] = {
            "dtype": str(op_dtype).replace("paddle.", ""),
            "shape": op_attrs["shape"],
        }

    def _infer_result_dtype(self, op, op_name):
        """Return the dtype of the op's first result, or None without results.

        Raises:
            AssertionError: If the first result is not a dense tensor and the
                op is not one of the known vector-result ops.
        """
        results = op.results()
        if len(results) == 0:
            return None
        out = results[0]
        if out.is_dense_tensor_type():
            return out.dtype
        # Non-dense results (e.g. paddle.base.libpaddle.pir.VectorType) cannot
        # be inspected precisely; fall back to parsing the printed type.
        if op_name in self._VECTOR_RESULT_OPS:
            return self.parse_pir_value_dtypes(str(out.type()))[0]
        # Raised explicitly (not ``assert False``) so it survives ``python -O``.
        raise AssertionError(f"Unsupport op: {op}")

    def _record_op(self, op, op_name):
        """Count one compute op and record its dtype when it can be inferred."""
        self.num_ops += 1
        op_dtype = None
        try:
            op_dtype = self._infer_result_dtype(op, op_name)
        except Exception:
            # Best effort: a failed inference only marks the op as missing.
            # Only the first failure is reported to keep the log short.
            if self.num_ops_misses_dtypes == 0:
                print(f"dtype inference failed for {op_name}")
        if op_dtype is not None:
            self.update_op_stats(op_name, op_dtype)
        else:
            self.num_ops_misses_dtypes += 1

    def __call__(self, program):
        """Analyze every op in every block of ``program``.

        Raises:
            AssertionError: If ``program`` is not a PIR program, or an op is
                neither a ``pd_op.*`` nor a ``builtin.*`` op.
        """
        assert isinstance(program, paddle.base.libpaddle.pir.Program)

        # Reset everything, including the input table and completeness flag,
        # so results of a previous call cannot leak into this one.
        self.op_stats = {}
        self.input_dict = {}
        self.num_ops = 0
        self.num_ops_misses_dtypes = 0

        for block in program.blocks:
            for op in block.ops:
                full_name = op.name()
                if full_name == "pd_op.data":
                    self._record_input(op)
                elif full_name.startswith("pd_op."):
                    self._record_op(op, full_name.replace("pd_op.", ""))
                elif not full_name.startswith("builtin."):
                    raise AssertionError(f"Unrecognized op: {op}")

        self.is_complete = self.num_ops_misses_dtypes == 0

    def summary(self):
        """Print the total op count and how many ops lack an inferred dtype."""
        print(
            f"Totally {self.num_ops} operators, and {self.num_ops_misses_dtypes} operators failed to inference dtypes."
        )
139+
140+
141+
def collect_op_stats(model, model_path):
    """Convert ``model`` with ``paddle.jit.to_static`` and analyze its program.

    Returns the ``ProgramAnalyzer`` that was run on the static main program
    (after printing its summary), or ``None`` when any step of the conversion
    or analysis raises.
    """
    assert isinstance(model, paddle.nn.Layer), f"{type(model)=}"

    try:
        conversion_options = {
            "input_spec": get_input_spec(model_path),
            "full_graph": True,
            "backend": None,
        }
        static_model = paddle.jit.to_static(model, **conversion_options)
        static_model.eval()
        main_program = static_model.forward.concrete_program.main_program

        analyzer = ProgramAnalyzer()
        analyzer(main_program)
        analyzer.summary()
    except Exception:
        print("Failed with to_static")
        return None
    return analyzer
160+
161+
162+
def collect_model_stats(model_path, log_prompt):
    """Collect and print the statistics of the sample stored in ``model_path``.

    ``GraphModule`` is loaded from ``<model_path>/model.py`` and instantiated,
    then analyzed through ``collect_op_stats``. When that analysis fails, the
    op/input/parameter statistics stay empty, but the number of outputs, the
    source and the heuristic tag are still reported.

    Args:
        model_path: Directory of a single sample.
        log_prompt: Prefix passed to ``collect_stats_util.print_model_stats``.
    """
    file_path = os.path.join(model_path, "model.py")
    model_class = collect_stats_util.load_class_from_file(file_path, "GraphModule")
    model = model_class()

    model_size = 0
    input_dtypes = {}
    param_dtypes = {}
    ops_count_dict = {}
    op_dtypes = {}
    num_ops = 0
    is_complete = False

    program_analyzer = collect_op_stats(model, model_path)
    # Everything that reads the analyzer lives under this single guard, so a
    # failed conversion (None) can never be dereferenced.
    if program_analyzer is not None:
        num_ops = program_analyzer.num_ops
        is_complete = program_analyzer.is_complete

        for op_name, stat in sorted(program_analyzer.op_stats.items()):
            ops_count_dict[op_name] = stat.count
            for dtype_str, num in stat.op_dtypes.items():
                if dtype_str is not None and dtype_str != "None":
                    op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num

        inputs_params = utils.load_converted_from_text(f"{model_path}")
        params = inputs_params["weight_info"]
        inputs = inputs_params["input_info"]

        # Each program input is either a recorded weight (counted towards the
        # model size) or a recorded input; anything else is ignored.
        for name, value in program_analyzer.input_dict.items():
            dtype_str = value["dtype"]
            if name in params:
                model_size += math.prod(value["shape"])
                param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
            elif name in inputs:
                input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1

    num_outputs = collect_stats_util.get_number_of_returns(
        file_path, "GraphModule", "forward"
    )
    source, heuristic_tag = read_graph_source_and_tag(model_path)
    print(
        f"model_stats collection information: model_path={model_path}, method=to_static, is_ops_complete={is_complete}"
    )

    stats = collect_stats_util.ModelStats(
        model_path=model_path,
        num_inputs=sum(input_dtypes.values()),
        num_params=sum(param_dtypes.values()),
        num_outputs=num_outputs,
        num_ops=num_ops,
        model_size_in_billion=model_size / 1e9,
        input_dtypes=input_dtypes,
        param_dtypes=param_dtypes,
        op_dtypes=op_dtypes,
        ops=ops_count_dict,
        source=source,
        heuristic_tag=heuristic_tag,
    )
    collect_stats_util.print_model_stats(stats, log_prompt)
221+
222+
223+
def main(args):
    """Collect stats for one sample, or for every sample under a directory.

    With ``args.model_path`` set, that single sample is processed in this
    process. Otherwise the samples directory (``args.graph_net_samples_path``
    or the package default) is walked and each sample is processed in its own
    subprocess, so one failing or hanging sample cannot stop the sweep.

    Args:
        args: Parsed command-line namespace with ``model_path``,
            ``graph_net_samples_path``, ``device`` and ``log_prompt``.
    """
    if args.model_path is not None:
        assert os.path.isdir(args.model_path)
        assert is_single_model_dir(args.model_path)
        formatted_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        print(f"[{formatted_dt}] Collect information for {args.model_path}")
        collect_model_stats(args.model_path, args.log_prompt)
        return

    graph_net_samples_path = args.graph_net_samples_path
    if graph_net_samples_path is None:
        # Imported here because the bare name ``graph_net`` is never bound by
        # the module's ``from graph_net ... import`` statements.
        from graph_net.paddle import samples_util

        graph_net_samples_path = samples_util.get_default_samples_directory()

    timeout_sec = 600
    i = 0
    for root, _dirs, _files in os.walk(graph_net_samples_path):
        if not is_single_model_dir(root):
            continue
        print(f"[{i}] Collect information for {root}")
        cmd = [
            # Re-use the running interpreter so the child sees the same
            # environment; fall back to PATH lookup if it is unknown.
            sys.executable or "python",
            "-m",
            "graph_net.paddle.collect_stats",
            f"--device={args.device}",
            f"--model-path={root}",
            f"--log-prompt={args.log_prompt}",
        ]
        try:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout_sec,
            )
        except subprocess.TimeoutExpired:
            # Keep sweeping the remaining samples instead of aborting the run.
            print(f"Timed out after {timeout_sec} seconds while collecting {root}")
        else:
            print(result.stdout)
            if result.returncode != 0:
                print(result.stderr)
        i += 1
262+
263+
264+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Collect stats for computation graph samples. return 0 if success"
    )

    # (flag, default, help) for every option; all are optional strings.
    cli_options = (
        (
            "--device",
            "cuda",
            "Device for testing the compiler (e.g., 'cpu' or 'cuda')",
        ),
        (
            "--model-path",
            None,
            "Computation graph sample directory. e.g '../../paddle_samples/PaddleX/ResNet18'",
        ),
        (
            "--graph-net-samples-path",
            None,
            "GraphNet samples directory. e.g '../../paddle_samples'",
        ),
        (
            "--log-prompt",
            "graph-net-collect-stats-log",
            "Log prompt for stats log filtering.",
        ),
    )
    for flag, default_value, help_text in cli_options:
        parser.add_argument(
            flag,
            type=str,
            required=False,
            default=default_value,
            help=help_text,
        )

    args = parser.parse_args()
    main(args=args)

0 commit comments

Comments
 (0)