Skip to content

Commit da0ff51

Browse files
committed
Merge branch 'collect_info' into add_cv_samples_5_need_fix
2 parents 110c7a9 + 6a5b8db commit da0ff51

File tree

3 files changed: +24 additions, -14 deletions

graph_net/collect_stats_util.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import json
23
import importlib
34
import inspect
45
from dataclasses import dataclass, field
@@ -28,6 +29,7 @@ class ModelStats:
2829
model_size_in_billion: float = None
2930
input_dtypes: Dict[str, int] = field(default_factory=dict)
3031
param_dtypes: Dict[str, int] = field(default_factory=dict)
32+
input_shapes: Dict[str, list] = field(default_factory=dict)
3133
op_dtypes: Dict[str, int] = field(default_factory=dict)
3234
ops: Dict[str, int] = field(default_factory=dict)
3335
source: str = None
@@ -37,10 +39,6 @@ class ModelStats:
3739
def print_model_stats(stats, log_prompt):
3840
assert isinstance(stats, ModelStats), f"{type(stats)=}"
3941

40-
def dict_to_string(d):
41-
kv_list = [f"{k}:{v}" for k, v in d.items()]
42-
return " ".join(kv_list)
43-
4442
def print_with_log_prompt(key, value):
4543
print(
4644
f"{log_prompt} [ModelStats.{key}] model_path:{stats.model_path} {value}",
@@ -52,10 +50,11 @@ def print_with_log_prompt(key, value):
5250
print_with_log_prompt("num_outputs", stats.num_outputs)
5351
print_with_log_prompt("num_ops", stats.num_ops)
5452
print_with_log_prompt("model_size", f"{stats.model_size_in_billion}B")
55-
print_with_log_prompt("input_dtypes", dict_to_string(stats.input_dtypes))
56-
print_with_log_prompt("param_dtypes", dict_to_string(stats.param_dtypes))
57-
print_with_log_prompt("op_dtypes", dict_to_string(stats.op_dtypes))
58-
print_with_log_prompt("ops", dict_to_string(stats.ops))
53+
print_with_log_prompt("input_dtypes", json.dumps(stats.input_dtypes))
54+
print_with_log_prompt("param_dtypes", json.dumps(stats.param_dtypes))
55+
print_with_log_prompt("input_shapes", json.dumps(stats.input_shapes))
56+
print_with_log_prompt("op_dtypes", json.dumps(stats.op_dtypes))
57+
print_with_log_prompt("ops", json.dumps(stats.ops))
5958
print_with_log_prompt("source", stats.source)
6059
print_with_log_prompt("heuristic_tag", stats.heuristic_tag)
6160

graph_net/paddle/collect_stats.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,11 @@ def __call__(self, program):
109109
else:
110110
# for paddle.base.libpaddle.pir.VectorType, but cannot be accurately determined
111111
if op_name in [
112+
"broadcast_tensors",
113+
"distribute_fpn_proposals",
114+
"meshgrid",
112115
"split",
113116
"split_with_num",
114-
"meshgrid",
115-
"distribute_fpn_proposals",
116117
]:
117118
op_dtype = self.parse_pir_value_dtypes(
118119
str(out.type())
@@ -165,6 +166,7 @@ def collect_model_stats(model_path, log_prompt):
165166
model = model_class()
166167

167168
model_size = 0
169+
input_shapes = set()
168170
input_dtypes = {}
169171
param_dtypes = {}
170172
ops_count_dict = {}
@@ -190,6 +192,7 @@ def collect_model_stats(model_path, log_prompt):
190192
param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
191193
elif name in inputs.keys():
192194
input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
195+
input_shapes.add(str(value["shape"]))
193196

194197
num_outputs = collect_stats_util.get_number_of_returns(
195198
file_path, "GraphModule", "forward"
@@ -200,7 +203,7 @@ def collect_model_stats(model_path, log_prompt):
200203
program_analyzer.is_complete if program_analyzer is not None else False
201204
)
202205
print(
203-
f"model_stats collection information: model_path={model_path}, method=to_static, is_ops_complete={is_complete}"
206+
f"model_stats collection information: model_path={model_path} method=to_static is_ops_complete={is_complete}"
204207
)
205208

206209
stats = collect_stats_util.ModelStats(
@@ -212,6 +215,7 @@ def collect_model_stats(model_path, log_prompt):
212215
model_size_in_billion=model_size / 1e9,
213216
input_dtypes=input_dtypes,
214217
param_dtypes=param_dtypes,
218+
input_shapes=list(input_shapes),
215219
op_dtypes=op_dtypes,
216220
ops=ops_count_dict,
217221
source=source,

graph_net/torch/collect_stats.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,17 +336,23 @@ def collect_model_stats(model_path, device, log_prompt):
336336
op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
337337

338338
model_size = 0
339+
input_shapes = set()
339340
input_dtypes = {}
340341
param_dtypes = {}
341342
for name, arg_type in argument_name2types.items():
342-
if arg_type == torch.nn.parameter.Parameter:
343+
if (
344+
name.startswith("L_self_modules_")
345+
or arg_type == torch.nn.parameter.Parameter
346+
):
347+
# Some parameters like L_self_modules_bn1_buffers_running_mean_ are torch.Tensor.
343348
param_numel = math.prod(input_dict[name].shape)
344349
model_size += param_numel
345350
dtype_str = str(input_dict[name].dtype).replace("torch.", "")
346351
param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
347-
else:
352+
elif arg_type == torch.Tensor:
348353
dtype_str = str(input_dict[name].dtype).replace("torch.", "")
349354
input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
355+
input_shapes.add(str(list(input_dict[name].shape)))
350356

351357
num_outputs = collect_stats_util.get_number_of_returns(
352358
file_path, "GraphModule", "forward"
@@ -356,7 +362,7 @@ def collect_model_stats(model_path, device, log_prompt):
356362

357363
is_complete = meta_executor.is_complete if meta_executor is not None else False
358364
print(
359-
f"model_stats collection information: model_path={model_path}, method={method}, is_ops_complete={is_complete}"
365+
f"model_stats collection information: model_path={model_path} method={method} is_ops_complete={is_complete}"
360366
)
361367

362368
stats = collect_stats_util.ModelStats(
@@ -368,6 +374,7 @@ def collect_model_stats(model_path, device, log_prompt):
368374
model_size_in_billion=model_size / 1e9,
369375
input_dtypes=input_dtypes,
370376
param_dtypes=param_dtypes,
377+
input_shapes=list(input_shapes),
371378
op_dtypes=op_dtypes,
372379
ops=ops_count_dict,
373380
source=source,

0 commit comments

Comments (0)