Skip to content

Commit 7728c14

Browse files
author
ADchampion3
committed
translate to En
1 parent df8b852 commit 7728c14

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

graph_net/test/vision_model_extract_test.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@ def extract_visio_graph(model_name: str, model_path: str):
2020
random_input = torch.rand(batch_size, num_channels, height, width)
2121
normalized_input = normalize(random_input)
2222

23-
# 使用get_model下载模型
23+
# download models using `torchvision.get_model`
2424
# all_models = list_models(module=torchvision.models)
2525
# if(model_path not in all_models):
26-
# print("不存在该模型, 请校验模型名称是否相同")
26+
# print("Model not found")
2727
# return
28-
# model = get_model(model_path, weights="DEFAULT")
28+
# model = torchvision.get_model(model_path, weights="DEFAULT")
2929

30-
# 使用torch.hub下载模型
31-
# 相关使用办法见https://docs.pytorch.org/docs/stable/hub.html
32-
torch.hub.set_dir("../../../test") # 缓存目录默认为$TORCH_HOME/hub 如果没有设置环境变量则为 ~/.cache
30+
# download models using torch.hub
31+
# Refer to https://docs.pytorch.org/docs/stable/hub.html
32+
torch.hub.set_dir(
33+
"../../../test"
34+
) # The default cache directory is $TORCH_HOME/hub; if the environment variable is not set, it defaults to ~/.cache
3335
endpoints = torch.hub.list("pytorch/vision")
3436
if model_path not in endpoints:
3537
print("Model not found")
@@ -55,8 +57,10 @@ def extract_visio_graph(model_name: str, model_path: str):
5557
parser = argparse.ArgumentParser()
5658
parser.add_argument(
5759
"--model_name", type=str, default="resnet18"
58-
) # 模型名称(自定义,推荐与官网相同或者简写)
59-
parser.add_argument("--model_path", type=str, default="resnet18") # 官网定义模型名称
60+
) # Model name (customizable, recommended to be the same as the official name or an abbreviation)
61+
parser.add_argument(
62+
"--model_path", type=str, default="resnet18"
63+
) # Model name as defined on the official website
6064
parser.add_argument("--workspace", type=str, default=workspace_default)
6165
args = parser.parse_args()
6266

0 commit comments

Comments
 (0)