Skip to content

Commit c3a98de

Browse files
committed
change model test files
1 parent ae110b5 commit c3a98de

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

graph_net/test/vision_model_test.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from torchvision import transforms
77
import graph_net
88

9-
if __name__ == "__main__":
9+
10+
def extract_visio_graph(model_name, model_path):
1011
# Normalization parameters for ImageNet
1112
normalize = transforms.Normalize(
1213
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -20,16 +21,32 @@
2021
normalized_input = normalize(random_input)
2122

2223
# Instantiate model
23-
model = torchvision.models.get_model("resnet18", weights="DEFAULT")
24+
model = torchvision.models.get_model(model_path, weights="DEFAULT")
2425
model.eval()
2526

2627
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2728
model.to(device)
2829
normalized_input = normalized_input.to(device)
2930

30-
model = graph_net.torch.extract(name="resnet18", dynamic=True)(model)
31+
model = graph_net.torch.extract(name=model_name, dynamic=True)(model)
3132

3233
print("Running inference...")
3334
print("Input shape:", normalized_input.shape)
3435
output = model(normalized_input)
3536
print("Inference finished. Output shape:", output.shape)
37+
38+
39+
if __name__ == "__main__":
40+
# get parameters from command line
41+
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "workspace")
42+
43+
parser = argparse.ArgumentParser()
44+
parser.add_argument("--model_name", type=str, default="resnet18")
45+
parser.add_argument("--model_path", type=str, default="resnet18") # timm 模型名称
46+
parser.add_argument("--workspace", type=str, default=workspace_default)
47+
parser.add_argument("--dynamic", type=bool, default=True)
48+
args = parser.parse_args()
49+
50+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
51+
52+
extract_visio_graph(args.model_name, args.model_path)

0 commit comments

Comments
 (0)