Skip to content

Commit 1a3dc30

Browse files
committed
Reorganize some code.
1 parent 0ce778f commit 1a3dc30

File tree

2 files changed

+129
-113
lines changed

2 files changed

+129
-113
lines changed

graph_net/collect_stats_util.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import ast
import importlib
import importlib.util
import inspect
from dataclasses import dataclass, field
from typing import Dict, Optional
6+
7+
8+
@dataclass
class OpStat:
    """Per-operator statistics: an occurrence count plus a dtype histogram."""

    op_name: str
    # Maps dtype string -> number of times this op ran with that dtype.
    op_dtypes: dict[str, int] = field(default_factory=dict)
    count: int = 0

    def update(self, other):
        """Merge *other* into self; silently ignored unless it is an OpStat for the same op."""
        if not isinstance(other, OpStat) or self.op_name != other.op_name:
            return
        self.count += other.count
        for dtype_name, dtype_count in other.op_dtypes.items():
            self.op_dtypes[dtype_name] = self.op_dtypes.get(dtype_name, 0) + dtype_count
19+
20+
21+
@dataclass
class ModelStats:
    """Aggregated statistics collected for a single model sample.

    Scalar fields default to None (not yet collected); histogram fields
    default to empty dicts mapping a name/dtype string to an occurrence count.
    """

    # Path of the model sample directory this record describes.
    model_path: str
    # Fix: these fields default to None, so annotate them Optional instead of
    # the bare types the original declared.
    num_inputs: Optional[int] = None
    num_params: Optional[int] = None
    num_outputs: Optional[int] = None
    num_ops: Optional[int] = None
    model_size_in_billion: Optional[float] = None
    # dtype string -> count histograms.
    input_dtypes: Dict[str, int] = field(default_factory=dict)
    param_dtypes: Dict[str, int] = field(default_factory=dict)
    op_dtypes: Dict[str, int] = field(default_factory=dict)
    # op name -> count histogram.
    ops: Dict[str, int] = field(default_factory=dict)
    # Provenance of the graph and its heuristic classification tag
    # (read from graph_net.json by the caller).
    source: Optional[str] = None
    heuristic_tag: Optional[str] = None
35+
36+
37+
def print_model_stats(stats, log_prompt):
    """Emit one `[ModelStats.<key>]` log line per field of *stats*.

    Each line is prefixed with *log_prompt* and the model path so downstream
    tooling can grep a single stat across many models.
    """
    assert isinstance(stats, ModelStats), f"{type(stats)=}"

    def dict_to_string(d):
        # Render a histogram as space-separated "key:value" tokens.
        return " ".join(f"{k}:{v}" for k, v in d.items())

    def print_with_log_prompt(key, value):
        print(
            f"{log_prompt} [ModelStats.{key}] model_path:{stats.model_path} {value}",
            flush=True,
        )

    # Table-driven emission: (log key, rendered value) in a fixed order.
    for key, value in (
        ("num_inputs", stats.num_inputs),
        ("num_params", stats.num_params),
        ("num_outputs", stats.num_outputs),
        ("num_ops", stats.num_ops),
        ("model_size", f"{stats.model_size_in_billion}B"),
        ("input_dtypes", dict_to_string(stats.input_dtypes)),
        ("param_dtypes", dict_to_string(stats.param_dtypes)),
        ("op_dtypes", dict_to_string(stats.op_dtypes)),
        ("ops", dict_to_string(stats.ops)),
        ("source", stats.source),
        ("heuristic_tag", stats.heuristic_tag),
    ):
        print_with_log_prompt(key, value)
61+
62+
63+
def load_class_from_file(file_path, class_name):
    """Import *file_path* as a throwaway module and fetch *class_name* from it.

    Returns None when the executed module defines no attribute with that name.
    Note: executing the module runs its top-level code.
    """
    module_spec = importlib.util.spec_from_file_location("unnamed", file_path)
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return getattr(loaded_module, class_name, None)
69+
70+
71+
def get_argument_name_and_types(model_class, func_name):
    """Map each parameter name of `model_class.func_name` to its annotation.

    `self` is excluded and an unannotated parameter maps to None.  Returns an
    empty dict when no function member with that name exists.
    """
    for member_name, member in inspect.getmembers(
        model_class, predicate=inspect.isfunction
    ):
        if member_name != func_name:
            continue
        return {
            param_name: (
                None if param.annotation is inspect._empty else param.annotation
            )
            for param_name, param in inspect.signature(member).parameters.items()
            if param_name != "self"
        }
    return {}
81+
82+
83+
def get_number_of_returns(file_path, class_name, func_name):
    """Statically count how many values `class_name.func_name` returns.

    Parses *file_path* and inspects the first `return` statement of the named
    method: a bare `return` (or no return at all, or no such class/method)
    yields 0, a tuple return yields the tuple arity, anything else yields 1.

    Fix over the original: the original used `ast.walk`, which descends into
    nested `def`/`lambda` scopes, so a `return` belonging to an inner function
    was miscounted as the method's own.  Nested scopes are now skipped, and
    `async def` methods are matched as well.
    """
    with open(file_path, "r") as f:
        source = f.read()

    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            for func in node.body:
                if (
                    isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and func.name == func_name
                ):
                    for stmt in _walk_excluding_nested_functions(func):
                        if isinstance(stmt, ast.Return):
                            if stmt.value is None:
                                return 0
                            elif isinstance(stmt.value, ast.Tuple):
                                return len(stmt.value.elts)
                            else:
                                return 1
    return 0


def _walk_excluding_nested_functions(func):
    """Yield descendants of *func*'s body, without entering nested function scopes."""
    pending = list(func.body)
    while pending:
        node = pending.pop(0)
        yield node
        # Do not descend into nested defs/lambdas: their returns are not ours.
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
            pending.extend(ast.iter_child_nodes(node))

graph_net/paddle/collect_stats.py

Lines changed: 28 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2,65 +2,19 @@
22
import os
33
import re
44
import sys
5-
import ast
65
import math
7-
import importlib
8-
import inspect
96
import subprocess
107
from datetime import datetime
11-
from typing import Type
12-
from dataclasses import dataclass, field
13-
from collections import defaultdict
148

159
import paddle
10+
from graph_net import collect_stats_util
1611
from graph_net.paddle import utils
1712

1813

1914
def is_single_model_dir(model_dir):
    """True when *model_dir* directly contains a graph_net.json marker file."""
    marker_path = f"{model_dir}/graph_net.json"
    return os.path.isfile(marker_path)
2116

2217

23-
def load_class_from_file(file_path: str, class_name: str) -> Type[paddle.nn.Layer]:
24-
spec = importlib.util.spec_from_file_location("unnamed", file_path)
25-
unnamed = importlib.util.module_from_spec(spec)
26-
spec.loader.exec_module(unnamed)
27-
model_class = getattr(unnamed, class_name, None)
28-
return model_class
29-
30-
31-
def get_argument_name_and_types(model_class, func_name):
32-
argument_name2types = {}
33-
for name, func in inspect.getmembers(model_class, predicate=inspect.isfunction):
34-
if name == func_name:
35-
for arg_name, arg in inspect.signature(func).parameters.items():
36-
if arg_name != "self":
37-
argument_name2types[arg_name] = (
38-
None if arg.annotation is inspect._empty else arg.annotation
39-
)
40-
return argument_name2types
41-
42-
43-
def get_number_of_returns(file_path, class_name, func_name):
44-
source = None
45-
with open(f"{file_path}", "r") as f:
46-
source = f.read()
47-
48-
tree = ast.parse(source)
49-
for node in tree.body:
50-
if isinstance(node, ast.ClassDef) and node.name == class_name:
51-
for f in node.body:
52-
if isinstance(f, ast.FunctionDef) and f.name == func_name:
53-
for stmt in ast.walk(f):
54-
if isinstance(stmt, ast.Return):
55-
if stmt.value is None:
56-
return 0
57-
elif isinstance(stmt.value, ast.Tuple):
58-
return len(stmt.value.elts)
59-
else:
60-
return 1
61-
return 0
62-
63-
6418
def read_graph_source_and_tag(model_path):
6519
try:
6620
with open(os.path.join(model_path, "graph_net.json"), "r") as f:
@@ -87,19 +41,6 @@ def get_input_spec(model_path):
8741
return input_spec
8842

8943

90-
@dataclass
91-
class OpStat:
92-
op_name: str
93-
op_dtypes: dict[str, int] = field(default_factory=dict)
94-
count: int = 0
95-
96-
def update(self, other):
97-
if isinstance(other, OpStat) and self.op_name == other.op_name:
98-
self.count += other.count
99-
for name, count in other.op_dtypes.items():
100-
self.op_dtypes[name] = self.op_dtypes.get(name, 0) + count
101-
102-
10344
class ProgramAnalyzer:
10445
def __init__(self):
10546
self.op_stats = {}
@@ -112,7 +53,9 @@ def update_op_stats(self, op_name, op_dtype):
11253
if op_name is not None:
11354
dtype_str = str(op_dtype).replace("paddle.", "")
11455
if self.op_stats.get(op_name, None) is None:
115-
self.op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
56+
self.op_stats[op_name] = collect_stats_util.OpStat(
57+
op_name, {dtype_str: 1}, 1
58+
)
11659
else:
11760
self.op_stats[op_name].op_dtypes[dtype_str] = (
11861
self.op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
@@ -213,9 +156,8 @@ def collect_op_stats(model, model_path):
213156

214157
def collect_model_stats(model_path, log_prompt):
215158
file_path = os.path.join(model_path, "model.py")
216-
model_class = load_class_from_file(file_path, "GraphModule")
159+
model_class = collect_stats_util.load_class_from_file(file_path, "GraphModule")
217160
model = model_class()
218-
num_outputs = get_number_of_returns(file_path, "GraphModule", "forward")
219161

220162
model_size = 0
221163
input_dtypes = {}
@@ -244,39 +186,33 @@ def collect_model_stats(model_path, log_prompt):
244186
elif name in inputs.keys():
245187
input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
246188

247-
model_size_in_billion = model_size / 1e9
248-
num_params = sum(param_dtypes.values())
249-
num_inputs = sum(input_dtypes.values())
250-
num_ops = sum(ops_count_dict.values())
189+
num_outputs = collect_stats_util.get_number_of_returns(
190+
file_path, "GraphModule", "forward"
191+
)
192+
num_ops = program_analyzer.num_ops if program_analyzer is not None else 0
251193
source, heuristic_tag = read_graph_source_and_tag(model_path)
252-
method = "to_static"
253194
is_complete = (
254195
program_analyzer.is_complete if program_analyzer is not None else False
255196
)
197+
print(
198+
f"model_stats collection information: model_path={model_path}, method=to_static, is_ops_complete={is_complete}"
199+
)
256200

257-
def dict_to_string(d):
258-
kv_list = [f"{k}:{v}" for k, v in d.items()]
259-
return " ".join(kv_list)
260-
261-
def print_with_log_prompt(key, value):
262-
print(
263-
f"{log_prompt} [ModelStats.{key}] model_path:{model_path} {value}",
264-
flush=True,
265-
)
266-
267-
print_with_log_prompt("num_inputs", num_inputs)
268-
print_with_log_prompt("num_params", num_params)
269-
print_with_log_prompt("num_outputs", num_outputs)
270-
print_with_log_prompt("num_ops", num_ops)
271-
print_with_log_prompt("model_size", f"{model_size_in_billion}B")
272-
print_with_log_prompt("input_dtypes", dict_to_string(input_dtypes))
273-
print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes))
274-
print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes))
275-
print_with_log_prompt("ops", dict_to_string(ops_count_dict))
276-
print_with_log_prompt("source", source)
277-
print_with_log_prompt("heuristic_tag", heuristic_tag)
278-
print_with_log_prompt("method", method)
279-
print_with_log_prompt("is_complete", is_complete)
201+
stats = collect_stats_util.ModelStats(
202+
model_path=model_path,
203+
num_inputs=sum(input_dtypes.values()),
204+
num_params=sum(param_dtypes.values()),
205+
num_outputs=num_outputs,
206+
num_ops=num_ops,
207+
model_size_in_billion=model_size / 1e9,
208+
input_dtypes=input_dtypes,
209+
param_dtypes=param_dtypes,
210+
op_dtypes=op_dtypes,
211+
ops=ops_count_dict,
212+
source=source,
213+
heuristic_tag=heuristic_tag,
214+
)
215+
collect_stats_util.print_model_stats(stats, log_prompt)
280216

281217

282218
def main(args):
@@ -295,23 +231,9 @@ def main(args):
295231
else args.graph_net_samples_path
296232
)
297233

298-
previous_failed_model_pathes = []
299-
if args.previous_collect_result_path is not None:
300-
with open(args.previous_collect_result_path, "r") as f:
301-
for line in f.readlines():
302-
if "[ModelStats]" in line:
303-
fields = line.strip().split()
304-
model_path = fields[2].split(":")[-1]
305-
is_complete = fields[-1].split(":")[-1]
306-
if is_complete == "False":
307-
previous_failed_model_pathes.append(model_path)
308-
309234
i = 0
310235
for root, dirs, files in os.walk(graph_net_samples_path):
311-
if is_single_model_dir(root) and (
312-
args.previous_collect_result_path is None
313-
or root in previous_failed_model_pathes
314-
):
236+
if is_single_model_dir(root):
315237
print(f"[{i}] Collect information for {root}")
316238
cmd = [
317239
"python",
@@ -359,13 +281,6 @@ def main(args):
359281
default=None,
360282
help="GraphNet samples directory. e.g '../../paddle_samples'",
361283
)
362-
parser.add_argument(
363-
"--previous-collect-result-path",
364-
type=str,
365-
required=False,
366-
default=None,
367-
help="Previous collect result path, use to recollect the failed cases",
368-
)
369284
parser.add_argument(
370285
"--log-prompt",
371286
type=str,

0 commit comments

Comments
 (0)