Skip to content

Commit 8161df7

Browse files
committed
Fix several ops and change to use subprocess for multiple tests.
1 parent b2073f9 commit 8161df7

File tree

1 file changed

+60
-15
lines changed

1 file changed

+60
-15
lines changed

graph_net/torch/collect_stats.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import math
55
import importlib
66
import inspect
7+
import subprocess
78
from typing import Type
89
from dataclasses import dataclass, field
910
from collections import defaultdict
@@ -62,11 +63,31 @@ def resolve_native_multi_head_attention(*args, **kwargs):
6263
(seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
6364
)
6465

65-
# seq_len_k = key.shape[0]
66-
# num_heads = args[4]
67-
# attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
68-
# dtype=query.dtype, device='meta')
69-
return attn_output # , attn_output_weights
66+
# TODO(Xreki): get value from args
67+
need_weights = False
68+
if need_weights:
69+
seq_len_k = key.shape[0]
70+
num_heads = args[4]
71+
attn_output_weights = torch.empty(
72+
(batch_size, num_heads, seq_len, seq_len_k),
73+
dtype=query.dtype,
74+
device="meta",
75+
)
76+
return attn_output, attn_output_weights
77+
else:
78+
return attn_output
79+
80+
81+
def resolve_tensor_to(tensor, *args, **kwargs):
    """Resolve the output of a traced ``Tensor.to(...)`` call on the meta device.

    ``Tensor.to`` may be called as ``to(dtype)``, ``to(device)``,
    ``to(device, dtype)``, or with keyword arguments only (e.g.
    ``to(dtype=torch.float16)``). The original implementation read
    ``args[0]`` unconditionally, which raises ``IndexError`` for the
    keyword-only form and silently ignores a dtype passed in a later
    positional slot or via ``kwargs``.

    Args:
        tensor: the (meta) input tensor whose shape is preserved.
        *args: positional arguments of the original ``.to`` call.
        **kwargs: keyword arguments of the original ``.to`` call.

    Returns:
        An uninitialized meta tensor with ``tensor``'s shape and the
        requested dtype (falling back to ``tensor.dtype`` when no dtype
        is specified — e.g. a pure device move).
    """
    dtype = kwargs.get("dtype")
    if dtype is None:
        # Scan all positionals: the dtype may appear first (`to(dtype)`)
        # or second (`to(device, dtype)`).
        for arg in args:
            if isinstance(arg, torch.dtype):
                dtype = arg
                break
    if dtype is None:
        # No dtype requested (e.g. `.to(device)`): dtype is unchanged.
        dtype = tensor.dtype
    return torch.empty(tensor.shape, dtype=dtype, device="meta")
87+
88+
89+
def resolve_tensor_item(tensor):
    """Stand-in for a traced ``Tensor.item()`` call on the meta device.

    A real ``.item()`` returns a Python scalar, which cannot be produced
    from a meta tensor; for shape/dtype propagation we represent the
    result as a 0-dim meta tensor carrying the input's dtype.
    """
    scalar_shape = ()
    return torch.empty(scalar_shape, device="meta", dtype=tensor.dtype)
7091

7192

7293
def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
@@ -115,6 +136,7 @@ def collect_op_stats(model, input_dict):
115136
)
116137

117138
try:
139+
# if True:
118140
if node.op == "call_module":
119141
# classname of module
120142
submod = traced.get_submodule(node.target)
@@ -133,8 +155,15 @@ def collect_op_stats(model, input_dict):
133155
op_func = getattr(self_obj, node.target)
134156
node_args = node_args[1:]
135157

158+
# print(f"node.op={node.op}, op_name={op_name}, node.args={node.args}")
136159
if op_name == "_native_multi_head_attention":
137160
out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
161+
elif op_name == "to":
162+
out = resolve_tensor_to(
163+
node_outputs[node.args[0].name], *node_args, **node_kwargs
164+
)
165+
elif op_name == "item":
166+
out = resolve_tensor_item(node_outputs[node.args[0].name])
138167
else:
139168
out = op_func(*node_args, **node_kwargs)
140169
node_outputs[node.name] = out
@@ -172,12 +201,6 @@ def collect_op_stats(model, input_dict):
172201

173202

174203
def collect_model_stats(model_path, device, log_prompt):
175-
if not hasattr(collect_model_stats, "_counter"):
176-
collect_model_stats._counter = 0
177-
else:
178-
collect_model_stats._counter += 1
179-
print(f"[{collect_model_stats._counter}] Collect information for {model_path}")
180-
181204
model_class = load_class_from_file(
182205
os.path.join(model_path, "model.py"), "GraphModule"
183206
)
@@ -187,16 +210,18 @@ def collect_model_stats(model_path, device, log_prompt):
187210
num_ops = 0
188211
num_inputs = 0
189212
num_outputs = 0
213+
ops_count_info = []
190214
dtypes = set()
191215
is_complete, op_stats = collect_op_stats(model, input_dict)
192216
if op_stats is not None:
193-
for op_name, stat in op_stats.items():
217+
for op_name, stat in sorted(op_stats.items()):
194218
if op_name == "placeholder":
195219
num_inputs += stat.count
196220
elif op_name == "output":
197221
num_outputs += stat.count
198222
else:
199223
num_ops += stat.count
224+
ops_count_info.append(f"{op_name}={stat.count}")
200225
for v in stat.dtype:
201226
if v is not None:
202227
dtypes.add(v)
@@ -213,11 +238,11 @@ def collect_model_stats(model_path, device, log_prompt):
213238
param_dtypes.add(str(input_dict[name].dtype).replace("torch.", ""))
214239
num_params_in_billion = num_params / 1e9
215240

241+
ops_str = "[" + ",".join(ops_count_info) + "]"
216242
dtypes_str = "[" + ",".join(dtypes) + "]"
217243
param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
218244
print(
219-
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}",
220-
file=sys.stderr,
245+
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete} ops:{ops_str}",
221246
flush=True,
222247
)
223248

@@ -226,16 +251,36 @@ def main(args):
226251
if args.model_path is not None:
227252
assert os.path.isdir(args.model_path)
228253
assert is_single_model_dir(args.model_path)
254+
print(f"Collect information for {args.model_path}")
229255
collect_model_stats(args.model_path, args.device, args.log_prompt)
230256
else:
231257
graph_net_samples_path = (
232258
(graph_net.torch.samples_util.get_default_samples_directory())
233259
if args.graph_net_samples_path is None
234260
else args.graph_net_samples_path
235261
)
262+
i = 0
236263
for root, dirs, files in os.walk(graph_net_samples_path):
237264
if is_single_model_dir(root):
238-
collect_model_stats(root, args.device, args.log_prompt)
265+
print(f"[{i}] Collect information for {root}")
266+
cmd = [
267+
"python",
268+
"-m",
269+
"graph_net.torch.collect_stats",
270+
f"--device={args.device}",
271+
f"--model-path={root}",
272+
f"--log-prompt={args.log_prompt}",
273+
]
274+
result = subprocess.run(
275+
cmd,
276+
stdout=subprocess.PIPE,
277+
stderr=subprocess.PIPE,
278+
text=True,
279+
timeout=600,
280+
)
281+
if result.returncode == 0:
282+
print(result.stdout)
283+
i += 1
239284

240285

241286
if __name__ == "__main__":

0 commit comments

Comments
 (0)