@@ -20,16 +20,18 @@ def extract_visio_graph(model_name: str, model_path: str):
2020 random_input = torch .rand (batch_size , num_channels , height , width )
2121 normalized_input = normalize (random_input )
2222
23- # 使用get_model下载模型
23+ # download models using `torchvision.get_model`
2424 # all_models = list_models(module=torchvision.models)
2525 # if(model_path not in all_models):
26- # print("不存在该模型, 请校验模型名称是否相同 ")
26+ # print("Model not found ")
2727 # return
28- # model = get_model(model_path, weights="DEFAULT")
28+ # model = torchvision. get_model(model_path, weights="DEFAULT")
2929
30- # 使用torch.hub下载模型
31- # 相关使用办法见https://docs.pytorch.org/docs/stable/hub.html
32- torch .hub .set_dir ("../../../test" ) # 缓存目录默认为$TORCH_HOME/hub 如果没有设置环境变量则为 ~/.cache
30+ # download models using torch.hub
31+ # Refer to https://docs.pytorch.org/docs/stable/hub.html
32+ torch .hub .set_dir (
33+ "../../../test"
34+ ) # The default cache directory is $TORCH_HOME/hub; if the environment variable is not set, it defaults to ~/.cache
3335 endpoints = torch .hub .list ("pytorch/vision" )
3436 if model_path not in endpoints :
3537 print ("Model not found" )
@@ -55,8 +57,10 @@ def extract_visio_graph(model_name: str, model_path: str):
5557 parser = argparse .ArgumentParser ()
5658 parser .add_argument (
5759 "--model_name" , type = str , default = "resnet18"
58- ) # 模型名称(自定义,推荐与官网相同或者简写)
59- parser .add_argument ("--model_path" , type = str , default = "resnet18" ) # 官网定义模型名称
60+ ) # Model name (customizable, recommended to be the same as the official name or an abbreviation)
61+ parser .add_argument (
62+ "--model_path" , type = str , default = "resnet18"
63+ ) # Model name as defined on the official website
6064 parser .add_argument ("--workspace" , type = str , default = workspace_default )
6165 args = parser .parse_args ()
6266
0 commit comments