|
| 1 | +import argparse |
| 2 | +import json |
| 3 | +import os |
| 4 | +from pathlib import Path |
| 5 | +from urllib.request import urlopen, urlretrieve |
| 6 | + |
| 7 | + |
| 8 | +def retrieve_otx_model(data_dir, model_name, format="xml"): |
| 9 | + destination_folder = os.path.join(data_dir, "otx_models") |
| 10 | + os.makedirs(destination_folder, exist_ok=True) |
| 11 | + if format == "onnx": |
| 12 | + urlretrieve( |
| 13 | + f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/model.onnx", |
| 14 | + f"{destination_folder}/{model_name}.onnx", |
| 15 | + ) |
| 16 | + else: |
| 17 | + urlretrieve( |
| 18 | + f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/openvino.xml", |
| 19 | + f"{destination_folder}/{model_name}.xml", |
| 20 | + ) |
| 21 | + urlretrieve( |
| 22 | + f"https://storage.openvinotoolkit.org/repositories/model_api/test/otx_models/{model_name}/openvino.bin", |
| 23 | + f"{destination_folder}/{model_name}.bin", |
| 24 | + ) |
| 25 | + |
| 26 | + |
| 27 | +def prepare_model( |
| 28 | + data_dir="./data", |
| 29 | + public_scope=Path(__file__).resolve().parent / "public_scope.json", |
| 30 | +): |
| 31 | + # TODO refactor this test so that it does not use eval |
| 32 | + # flake8: noqa: F401 |
| 33 | + from model_api.models import ClassificationModel, DetectionModel, SegmentationModel |
| 34 | + |
| 35 | + with open(public_scope, "r") as f: |
| 36 | + public_scope = json.load(f) |
| 37 | + |
| 38 | + for model in public_scope: |
| 39 | + if model["name"].endswith(".xml") or model["name"].endswith(".onnx"): |
| 40 | + continue |
| 41 | + model = eval(model["type"]).create_model(model["name"], download_dir=data_dir) |
| 42 | + |
| 43 | + |
| 44 | +def prepare_data(data_dir="./data"): |
| 45 | + from io import BytesIO |
| 46 | + from zipfile import ZipFile |
| 47 | + |
| 48 | + COCO128_URL = "https://ultralytics.com/assets/coco128.zip" |
| 49 | + |
| 50 | + with urlopen(COCO128_URL) as zipresp: |
| 51 | + with ZipFile(BytesIO(zipresp.read())) as zfile: |
| 52 | + zfile.extractall(data_dir) |
| 53 | + |
| 54 | + urlretrieve( |
| 55 | + "https://raw.githubusercontent.com/Shenggan/BCCD_Dataset/master/BCCD/JPEGImages/BloodImage_00007.jpg", |
| 56 | + os.path.join(data_dir, "BloodImage_00007.jpg"), |
| 57 | + ) |
| 58 | + |
| 59 | + |
| 60 | +if __name__ == "__main__": |
| 61 | + parser = argparse.ArgumentParser(description="Data and model preparate script") |
| 62 | + parser.add_argument( |
| 63 | + "-d", |
| 64 | + dest="data_dir", |
| 65 | + default="./data", |
| 66 | + help="Directory to store downloaded models and datasets", |
| 67 | + ) |
| 68 | + parser.add_argument( |
| 69 | + "-p", |
| 70 | + dest="public_scope", |
| 71 | + default=Path(__file__).resolve().parent / "public_scope.json", |
| 72 | + help="JSON file with public model description", |
| 73 | + ) |
| 74 | + |
| 75 | + args = parser.parse_args() |
| 76 | + |
| 77 | + prepare_model(args.data_dir, args.public_scope) |
| 78 | + prepare_data(args.data_dir) |
| 79 | + retrieve_otx_model(args.data_dir, "mlc_mobilenetv3_large_voc") |
| 80 | + retrieve_otx_model(args.data_dir, "detection_model_with_xai_head") |
| 81 | + retrieve_otx_model(args.data_dir, "Lite-hrnet-18_mod2") |
| 82 | + retrieve_otx_model(args.data_dir, "tinynet_imagenet") |
| 83 | + retrieve_otx_model(args.data_dir, "cls_mobilenetv3_large_cars", "onnx") |
0 commit comments