-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathprepare_model.py
More file actions
57 lines (43 loc) · 1.67 KB
/
prepare_model.py
File metadata and controls
57 lines (43 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import os
import sys
import urllib.request
# Direct (raw) GitHub link to the resnet50-v1-12.onnx file in the ONNX Model Zoo.
MODEL_URL = (
    "https://github.com/onnx/models/raw/main/"
    "validated/vision/classification/resnet/model/resnet50-v1-12.onnx"
)
# Number of download attempts made before giving up.
MAX_TIMES_RETRY_DOWNLOAD = 5
def parse_arguments():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with two string attributes:
        ``input_model`` (optional, defaults to ``"resnet50-v1-12.onnx"``) and
        ``output_model`` (mandatory).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--input_model", dict(type=str, required=False, default="resnet50-v1-12.onnx")),
        ("--output_model", dict(type=str, required=True)),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    parsed = cli.parse_args()
    return parsed
def progressbar(cur, total=100, width=100):
    """Draw a single-line text progress bar on stdout, redrawing in place.

    Args:
        cur: Amount of work done, in the same units as ``total``.
        total: Amount of work that corresponds to a full bar. Defaults to 100,
            so a percentage can be passed directly as ``cur``.
        width: Number of character cells between the brackets.

    The output starts with a carriage return, so consecutive calls overwrite
    the same terminal line. The filled portion is proportional to
    ``cur / total`` and clamped to ``[0, width]``. The printed percentage is
    the unclamped ratio. A ``total`` of 0 is shown as an empty bar at 0%.
    """
    if total:
        fraction = cur / total
        # Multiply before dividing: 57 * 100 / 100 == 57.0 exactly, whereas
        # (57 / 100) * 100 == 56.99999999999999 would truncate to 56 cells.
        filled = int(cur * width / total)
    else:
        fraction = 0.0
        filled = 0
    # Keep the bar inside its frame even if cur overshoots total.
    filled = max(0, min(width, filled))
    sys.stdout.write(f"\r[{'#' * filled:<{width}}] {fraction:.2%}")
    sys.stdout.flush()
def schedule(blocknum, blocksize, totalsize):
    """Report hook for ``urllib.request.urlretrieve`` that updates the progress bar.

    Args:
        blocknum: Number of blocks transferred so far.
        blocksize: Size of one block in bytes.
        totalsize: Total payload size in bytes. ``urlretrieve`` passes -1
            when the size is unknown (no Content-Length header).

    Converts the transfer state to a percentage in ``[0, 100]`` and hands it
    to ``progressbar``.
    """
    if totalsize <= 0:
        # Unknown (-1) or zero size: avoid dividing by zero or producing a
        # negative percentage.
        percent = 0
    else:
        percent = min(1.0, blocknum * blocksize / totalsize) * 100
    progressbar(percent)
def _discard_partial(path):
    """Delete a leftover partial download at *path*; do nothing if it is absent."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def download_model(url, model_name, retry_times=5):
    """Download *url* to *model_name*, retrying on failure.

    The data is first written to ``model_name + ".part"``. That file is renamed
    to *model_name* only after a complete transfer. Without this, a failed or
    interrupted transfer would leave a truncated file at *model_name*. A later
    run would then see that file, skip the download, and use a corrupt model.

    Args:
        url: Source URL passed to ``urllib.request.urlretrieve``.
        model_name: Destination file path. If a file already exists there,
            nothing is downloaded.
        retry_times: Maximum number of download attempts. With 0 or less,
            no attempt is made.

    Returns:
        True if *model_name* already existed or was downloaded successfully.
        False if every attempt failed or the download was interrupted with
        Ctrl-C (KeyboardInterrupt).
    """
    if os.path.isfile(model_name):
        print(f"{model_name} exists, skip download")
        return True
    print("download model...")
    partial_name = model_name + ".part"
    for attempt in range(1, retry_times + 1):
        try:
            urllib.request.urlretrieve(url, partial_name, schedule)
            # Same-directory rename, so the final name only ever refers to a
            # complete file.
            os.replace(partial_name, model_name)
        except KeyboardInterrupt:
            _discard_partial(partial_name)
            print()
            return False
        except Exception as err:
            _discard_partial(partial_name)
            # End the in-place progress-bar line before reporting.
            print()
            hint = ", Retry downloading..." if attempt < retry_times else "!"
            print(f"Download failed ({err}){hint}")
        else:
            print()  # finish the progress-bar line
            return True
    return False
def prepare_model(input_model, output_model):
    """Download the model at ``MODEL_URL`` to *output_model*.

    Args:
        input_model: Accepted for command-line compatibility but not used;
            the model always comes from ``MODEL_URL``.
        output_model: Path the downloaded model is written to.

    Returns:
        The result of ``download_model``: True if the model is available at
        *output_model*, False if downloading failed.
    """
    # Download model from [ONNX Model Zoo](https://github.com/onnx/models)
    return download_model(MODEL_URL, output_model, MAX_TIMES_RETRY_DOWNLOAD)


if __name__ == "__main__":
    args = parse_arguments()
    # Exit non-zero when the model could not be obtained, so calling scripts
    # and CI notice the failure instead of continuing with a missing model.
    sys.exit(0 if prepare_model(args.input_model, args.output_model) else 1)