Skip to content

Commit 2234cb1

Browse files
authored
Add some Scripts (#146)
* add functional tools * change model test files * change model test files * add work space * add timm
1 parent 9a30aa7 commit 2234cb1

File tree

4 files changed

+92
-5
lines changed

4 files changed

+92
-5
lines changed

graph_net/test/timm_model_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# pip install timm
2+
3+
import argparse
4+
import os
5+
import torch
6+
from torchvision import transforms
7+
import timm # 导入 timm 库
8+
import graph_net
9+
10+
os.environ["TIMMDL_DISABLE_RETRY"] = "1" # 禁用重试
11+
12+
13+
def extract_visio_graph(model_name, model_path, dynamic_mode):
14+
# Normalization parameters for ImageNet
15+
normalize = transforms.Normalize(
16+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
17+
)
18+
19+
# Create dummy input
20+
batch_size = 1
21+
height, width = 224, 224 # Standard ImageNet size
22+
num_channels = 3
23+
random_input = torch.rand(batch_size, num_channels, height, width)
24+
normalized_input = normalize(random_input)
25+
26+
# Instantiate model using timm
27+
model = timm.create_model(model_path, pretrained=False) # 使用 timm 加载 resnet18
28+
model.eval()
29+
30+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31+
model.to(device)
32+
normalized_input = normalized_input.to(device)
33+
34+
# Extract graph structure
35+
model = graph_net.torch.extract(name=model_name, dynamic=dynamic_mode)(model)
36+
37+
print("Running inference...")
38+
output = model(normalized_input)
39+
print("Inference finished. Output shape:", output.shape)
40+
41+
42+
if __name__ == "__main__":
43+
# get parameters from command line
44+
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "workspace")
45+
46+
parser = argparse.ArgumentParser()
47+
parser.add_argument("--model_name", type=str, default="resnet18")
48+
parser.add_argument("--model_path", type=str, default="resnet18") # timm 模型名称
49+
parser.add_argument("--workspace", type=str, default=workspace_default)
50+
parser.add_argument("--dynamic", type=bool, default=True)
51+
args = parser.parse_args()
52+
53+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
54+
55+
extract_visio_graph(args.model_name, args.model_path, args.dynamic)

graph_net/test/vision_model_test.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import argparse
22
import os
3-
import json
3+
44
import torch
55
import torchvision
66
from torchvision import transforms
7-
import os
87
import graph_net
98

10-
if __name__ == "__main__":
9+
10+
def extract_visio_graph(model_name, model_path):
1111
# Normalization parameters for ImageNet
1212
normalize = transforms.Normalize(
1313
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -21,16 +21,32 @@
2121
normalized_input = normalize(random_input)
2222

2323
# Instantiate model
24-
model = torchvision.models.get_model("resnet18", weights="DEFAULT")
24+
model = torchvision.models.get_model(model_path, weights="DEFAULT")
2525
model.eval()
2626

2727
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2828
model.to(device)
2929
normalized_input = normalized_input.to(device)
3030

31-
model = graph_net.torch.extract(name="resnet18", dynamic=True)(model)
31+
model = graph_net.torch.extract(name=model_name, dynamic=True)(model)
3232

3333
print("Running inference...")
3434
print("Input shape:", normalized_input.shape)
3535
output = model(normalized_input)
3636
print("Inference finished. Output shape:", output.shape)
37+
38+
39+
if __name__ == "__main__":
40+
# get parameters from command line
41+
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "workspace")
42+
43+
parser = argparse.ArgumentParser()
44+
parser.add_argument("--model_name", type=str, default="resnet18")
45+
parser.add_argument("--model_path", type=str, default="resnet18") # timm 模型名称
46+
parser.add_argument("--workspace", type=str, default=workspace_default)
47+
parser.add_argument("--dynamic", type=bool, default=True)
48+
args = parser.parse_args()
49+
50+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
51+
52+
extract_visio_graph(args.model_name, args.model_path)

graph_net/torch/validate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,12 @@ def main(args):
6666
action="store_true",
6767
help="whether check model graph redundancy",
6868
)
69+
parser.add_argument(
70+
"--workspace",
71+
default=os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "./workspace"),
72+
help="whether check model graph redundancy",
73+
)
6974
args = parser.parse_args()
75+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
76+
7077
main(args=args)

tools/count_sample.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# 统计 samples 目录下, graph_net.json 文件的数量
2+
import os
3+
4+
graph_net_count = 0
5+
for root, dirs, files in os.walk("../samples"):
6+
for file in files:
7+
if file == "graph_net.json":
8+
graph_net_count += 1
9+
print(f"Number of graph_net.json files: {graph_net_count}")

0 commit comments

Comments
 (0)