Skip to content

Commit 9cf8a86

Browse files
committed
Add source and heuristic_tag.
1 parent 9f3086a commit 9cf8a86

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

graph_net/torch/collect_stats.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import ast
55
import math
6+
import json
67
import importlib
78
import inspect
89
import subprocess
@@ -60,6 +61,36 @@ def get_number_of_returns(file_path, class_name, func_name):
6061
return 0
6162

6263

64+
def read_graph_source_and_tag(model_path):
65+
try:
66+
with open(os.path.join(model_path, "graph_net.json"), "r") as f:
67+
data = json.load(f)
68+
return data["source"], data["heuristic_tag"]
69+
except Exception:
70+
if "cosyvoice" in model_path:
71+
return "cosyvoice", "audio"
72+
elif "torchaudio" in model_path:
73+
return "torchaudio", "audio"
74+
elif "ultralytics" in model_path:
75+
return "ultralytics", "computer_vision"
76+
elif "torchvision" in model_path:
77+
return "torchvision", "computer_vision"
78+
elif "timm" in model_path:
79+
return "timm", "computer_vision"
80+
elif "mmseg" in model_path:
81+
return "mmseg", "computer_vision"
82+
elif "mmpose" in model_path:
83+
return "mmpose", "computer_vision"
84+
elif "torchgeometric" in model_path:
85+
return "torchgeometric", "other"
86+
elif "transformers-auto-model" in model_path:
87+
return "huggingface_hub", "unknown"
88+
elif "nemo" in model_path:
89+
return "nemo", "unknown"
90+
else:
91+
return "unknown", "unknown"
92+
93+
6394
def get_input_dict(model_path, device):
6495
inputs_params = utils.load_converted_from_text(f"{model_path}")
6596
params = inputs_params["weight_info"]
@@ -456,6 +487,8 @@ def collect_model_stats(model_path, device, log_prompt):
456487
model_size_in_billion = model_size / 1e9
457488
num_inputs = len(argument_name2types) - num_params
458489

490+
source, heuristic_tag = read_graph_source_and_tag(model_path)
491+
459492
def dict_to_string(d):
460493
kv_list = [f"{k}:{v}" for k, v in d.items()]
461494
return " ".join(kv_list)
@@ -475,6 +508,8 @@ def print_with_log_prompt(key, value):
475508
print_with_log_prompt("param_dtypes", dict_to_string(param_dtypes))
476509
print_with_log_prompt("op_dtypes", dict_to_string(op_dtypes))
477510
print_with_log_prompt("ops", dict_to_string(ops_count_dict))
511+
print_with_log_prompt("source", source)
512+
print_with_log_prompt("heuristic_tag", heuristic_tag)
478513
print_with_log_prompt("method", method)
479514
print_with_log_prompt("is_complete", is_complete)
480515

@@ -505,7 +540,10 @@ def main(args):
505540

506541
i = 0
507542
for root, dirs, files in os.walk(graph_net_samples_path):
508-
if is_single_model_dir(root) and root in previous_failed_model_pathes:
543+
if is_single_model_dir(root) and (
544+
args.previous_collect_result_path is None
545+
or root in previous_failed_model_pathes
546+
):
509547
print(f"[{i}] Collect information for {root}")
510548
cmd = [
511549
"python",

0 commit comments

Comments
 (0)