Skip to content

Commit 256c75f

Browse files
committed
Optimize the dtypes stats.
1 parent 07558f2 commit 256c75f

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

graph_net/torch/collect_stats.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_input_dict(model_path, device):
5252
@dataclass
5353
class OpStat:
5454
op_name: str
55-
dtype: set[str] = field(default_factory=set)
55+
op_dtypes: dict[str, int] = field(default_factory=dict)
5656
count: int = 0
5757

5858

@@ -124,7 +124,6 @@ def collect_op_stats_manual(model, input_dict):
124124
if node.op == "placeholder":
125125
node_outputs[node.name] = meta_input_dict[node.target]
126126
op_name = node.op
127-
dtype = node_outputs[node.name].dtype
128127
elif node.op in ["call_function", "call_module", "call_method"]:
129128
node_args = torch.fx.map_arg(
130129
node.args,
@@ -190,11 +189,13 @@ def collect_op_stats_manual(model, input_dict):
190189
assert False, f"node.op: {node.op}"
191190

192191
if op_name is not None:
193-
dtype_str = str(dtype).replace("torch.", "") if dtype is not None else None
192+
dtype_str = str(dtype).replace("torch.", "")
194193
if op_stats.get(op_name, None) is None:
195-
op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
194+
op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
196195
else:
197-
op_stats[op_name].dtype.add(dtype_str)
196+
op_stats[op_name].op_dtypes[dtype_str] = (
197+
op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
198+
)
198199
op_stats[op_name].count = op_stats[op_name].count + 1
199200
return is_complete, op_stats
200201

@@ -234,7 +235,7 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
234235
assert False, f"node.op: {node.op}"
235236

236237
dtype = None
237-
if node.op != "output":
238+
if node.op not in ["placeholder", "output"]:
238239
if "tensor_meta" in node.meta:
239240
tensor_meta = node.meta["tensor_meta"]
240241
dtype = tensor_meta.dtype
@@ -252,9 +253,11 @@ def collect_op_stats_with_make_fx(model, input_dict, arg_types):
252253
)
253254
dtype_str = str(dtype).replace("torch.", "")
254255
if op_stats.get(op_name, None) is None:
255-
op_stats[op_name] = OpStat(op_name, {dtype_str}, 1)
256+
op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
256257
else:
257-
op_stats[op_name].dtype.add(dtype_str)
258+
op_stats[op_name].op_dtypes[dtype_str] = (
259+
op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
260+
)
258261
op_stats[op_name].count = op_stats[op_name].count + 1
259262
return is_complete, op_stats
260263

@@ -280,8 +283,8 @@ def collect_model_stats(model_path, device, log_prompt):
280283

281284
num_ops = 0
282285
num_outputs = 0
283-
ops_count_info = []
284-
dtypes = set()
286+
ops_count_dict = {}
287+
op_dtypes = {}
285288
method, is_complete, op_stats = collect_op_stats(model, input_dict, arg_types)
286289
if op_stats is not None:
287290
for op_name, stat in sorted(op_stats.items()):
@@ -291,29 +294,48 @@ def collect_model_stats(model_path, device, log_prompt):
291294
num_outputs += stat.count
292295
else:
293296
num_ops += stat.count
294-
ops_count_info.append(f"{op_name}={stat.count}")
295-
for v in stat.dtype:
296-
if v is not None:
297-
dtypes.add(v)
297+
ops_count_dict[op_name] = stat.count
298+
for dtype_str, num in stat.op_dtypes.items():
299+
if dtype_str is not None and dtype_str != "None":
300+
op_dtypes[dtype_str] = op_dtypes.get(dtype_str, 0) + num
298301

299-
num_inputs = len(arg_types)
300302
num_params = 0
301-
param_dtypes = set()
303+
model_size = 0
304+
input_dtypes = {}
305+
param_dtypes = {}
302306
for name, arg_type in arg_types.items():
303307
if arg_type == torch.nn.parameter.Parameter:
304-
count = math.prod(input_dict[name].shape)
308+
param_numel = math.prod(input_dict[name].shape)
305309
# print(f"Parameter {name}: {param_numel}")
306-
num_params += count
307-
param_dtypes.add(str(input_dict[name].dtype).replace("torch.", ""))
308-
num_params_in_billion = num_params / 1e9
309-
310-
ops_str = "[" + ",".join(ops_count_info) + "]"
311-
dtypes_str = "[" + ",".join(dtypes) + "]"
312-
param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
313-
print(
314-
f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} method:{method} is_complete:{is_complete} ops:{ops_str}",
315-
flush=True,
316-
)
310+
num_params += 1
311+
model_size += param_numel
312+
dtype_str = str(input_dict[name].dtype).replace("torch.", "")
313+
param_dtypes[dtype_str] = param_dtypes.get(dtype_str, 0) + 1
314+
else:
315+
dtype_str = str(input_dict[name].dtype).replace("torch.", "")
316+
input_dtypes[dtype_str] = input_dtypes.get(dtype_str, 0) + 1
317+
model_size_in_billion = model_size / 1e9
318+
num_inputs = len(arg_types) - num_params
319+
320+
def dict_to_string(d):
321+
kv_list = [f"{k}={v}" for k, v in d.items()]
322+
return "{" + ",".join(kv_list) + "}"
323+
324+
log_fields = [log_prompt, "[ModelStats]"]
325+
log_fields.append(f"model_path:{model_path}")
326+
log_fields.append(f"num_inputs:{num_inputs}")
327+
log_fields.append(f"num_params:{num_params}")
328+
log_fields.append(f"num_outputs:{num_outputs}")
329+
log_fields.append(f"num_ops:{num_ops}")
330+
log_fields.append(f"model_size:{model_size_in_billion}B")
331+
log_fields.append(f"input_dtypes:{dict_to_string(input_dtypes)}")
332+
log_fields.append(f"param_dtypes:{dict_to_string(param_dtypes)}")
333+
log_fields.append(f"op_dtypes:{dict_to_string(op_dtypes)}")
334+
log_fields.append(f"ops:{dict_to_string(ops_count_dict)}")
335+
log_fields.append(f"method:{method}")
336+
log_fields.append(f"is_complete:{is_complete}")
337+
338+
print(" ".join(log_fields), flush=True)
317339

318340

319341
def main(args):

0 commit comments

Comments
 (0)