Skip to content

Commit 7081d81

Browse files
committed
add timm
1 parent 47b24ad commit 7081d81

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

graph_net/test/timm_model_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# pip install timm
2+
3+
import argparse
4+
import os
5+
import torch
6+
from torchvision import transforms
7+
import timm # 导入 timm 库
8+
import graph_net
9+
10+
os.environ["TIMMDL_DISABLE_RETRY"] = "1" # 禁用重试
11+
12+
13+
def extract_visio_graph(model_name, model_path, dynamic_mode):
14+
# Normalization parameters for ImageNet
15+
normalize = transforms.Normalize(
16+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
17+
)
18+
19+
# Create dummy input
20+
batch_size = 1
21+
height, width = 224, 224 # Standard ImageNet size
22+
num_channels = 3
23+
random_input = torch.rand(batch_size, num_channels, height, width)
24+
normalized_input = normalize(random_input)
25+
26+
# Instantiate model using timm
27+
model = timm.create_model(model_path, pretrained=False) # 使用 timm 加载 resnet18
28+
model.eval()
29+
30+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31+
model.to(device)
32+
normalized_input = normalized_input.to(device)
33+
34+
# Extract graph structure
35+
model = graph_net.torch.extract(name=model_name, dynamic=dynamic_mode)(model)
36+
37+
print("Running inference...")
38+
output = model(normalized_input)
39+
print("Inference finished. Output shape:", output.shape)
40+
41+
42+
if __name__ == "__main__":
43+
# get parameters from command line
44+
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "workspace")
45+
46+
parser = argparse.ArgumentParser()
47+
parser.add_argument("--model_name", type=str, default="resnet18")
48+
parser.add_argument("--model_path", type=str, default="resnet18") # timm 模型名称
49+
parser.add_argument("--workspace", type=str, default=workspace_default)
50+
parser.add_argument("--dynamic", type=bool, default=True)
51+
args = parser.parse_args()
52+
53+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
54+
55+
extract_visio_graph(args.model_name, args.model_path, args.dynamic)

0 commit comments

Comments
 (0)