Skip to content

Commit 0a6580a

Browse files
authored
[New Sample] Add Some TorchVision Graph (#395)
* add wide_resnet50_2 * add torchvision.(wide_resnet50_2, wide_resnet101_2) * pre-commit success * translate to En * add path --------- Co-authored-by: ADchampion3 <2427771853.com>
1 parent 5b16d63 commit 0a6580a

File tree

13 files changed

+23320
-0
lines changed

13 files changed

+23320
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
from torchvision import transforms
6+
7+
import graph_net
8+
9+
EXAMPLE_SAMPLE_REL_MODEL_PATHS = [
10+
"samples/torchvision/wide_resnet50_2",
11+
"samples/torchvision/wide_resnet101_2",
12+
]
13+
14+
15+
def extract_visio_graph(model_name: str, model_path: str):
16+
# Normalization parameters for ImageNet
17+
normalize = transforms.Normalize(
18+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
19+
)
20+
21+
# Create dummy input
22+
batch_size = 1
23+
height, width = 224, 224 # Standard ImageNet size
24+
num_channels = 3
25+
random_input = torch.rand(batch_size, num_channels, height, width)
26+
normalized_input = normalize(random_input)
27+
28+
# download models using `torchvision.get_model`
29+
# all_models = list_models(module=torchvision.models)
30+
# if(model_path not in all_models):
31+
# print("Model not found")
32+
# return
33+
# model = torchvision.get_model(model_path, weights="DEFAULT")
34+
35+
# download models using torch.hub
36+
# Refer to https://docs.pytorch.org/docs/stable/hub.html
37+
torch.hub.set_dir(
38+
"../../../test"
39+
) # The default cache directory is $TORCH_HOME/hub; if the environment variable is not set, it defaults to ~/.cache
40+
endpoints = torch.hub.list("pytorch/vision")
41+
if model_path not in endpoints:
42+
print("Model not found")
43+
return
44+
model = torch.hub.load("pytorch/vision", model_path)
45+
46+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47+
model.to(device)
48+
normalized_input = normalized_input.to(device)
49+
50+
model = graph_net.torch.extract(name=model_name, dynamic=True)(model)
51+
52+
print("Running inference...")
53+
print("Input shape:", normalized_input.shape)
54+
output = model(normalized_input)
55+
print("Inference finished. Output shape:", output.shape)
56+
57+
58+
if __name__ == "__main__":
59+
# get parameters from command line
60+
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "../../workspace")
61+
62+
parser = argparse.ArgumentParser()
63+
parser.add_argument(
64+
"--model_name", type=str, default="resnet18"
65+
) # Model name (customizable, recommended to be the same as the official name or an abbreviation)
66+
parser.add_argument(
67+
"--model_path", type=str, default="resnet18"
68+
) # Model name as defined on the official website
69+
parser.add_argument("--workspace", type=str, default=workspace_default)
70+
args = parser.parse_args()
71+
72+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
73+
74+
extract_visio_graph(args.model_name, args.model_path)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3845c4a06416a471520e3402df093384cc395f177a979ccd37c32b4c9282ffba
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"dynamic": true,
6+
"model_name": "wide_resnet101_2"
7+
}

samples/torchvision/wide_resnet101_2/input_meta.py

Whitespace-only changes.

samples/torchvision/wide_resnet101_2/input_tensor_constraints.py

Whitespace-only changes.

samples/torchvision/wide_resnet101_2/model.py

Lines changed: 7734 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)