|
| 1 | +import os |
| 2 | + |
| 3 | +# just expand this list when adding new models: |
| 4 | +MODELOPTIONS = [ |
| 5 | + "full_human", |
| 6 | + "full_cat", |
| 7 | + "primate_face", |
| 8 | + "mouse_pupil_vclose", |
| 9 | + "horse_sideview", |
| 10 | + "full_macaque", |
| 11 | + "superanimal_mouse", |
| 12 | +] |
| 13 | + |
| 14 | + |
| 15 | +def get_dlclibrary_path(): |
| 16 | + """Get path of where dlclibrary (this repo) is currently running""" |
| 17 | + import importlib.util |
| 18 | + return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0] |
| 19 | + |
| 20 | + |
| 21 | +def loadmodelnames(): |
| 22 | + """Load URLs and commits for available models""" |
| 23 | + from ruamel.yaml import YAML |
| 24 | + fn = os.path.join(get_dlclibrary_path(),"modelzoo_urls.yaml") |
| 25 | + with open(fn) as file: |
| 26 | + return YAML().load(file) |
| 27 | + |
| 28 | + |
| 29 | +def download_hugginface_model(modelname, target_dir,removeHFfolder=True): |
| 30 | + """ |
| 31 | + Downloads a DeepLabCut Model Zoo Project from Hugging Face |
| 32 | + """ |
| 33 | + from huggingface_hub import hf_hub_download |
| 34 | + import tarfile, os |
| 35 | + from pathlib import Path |
| 36 | + |
| 37 | + neturls = loadmodelnames() |
| 38 | + |
| 39 | + if modelname in neturls.keys(): |
| 40 | + print("Loading....", modelname) |
| 41 | + url = neturls[modelname].split("/") |
| 42 | + repo_id, targzfn = url[0] + "/" + url[1], str(url[-1]) |
| 43 | + |
| 44 | + hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir)) |
| 45 | + # creates a new subfolder as indicated below, unzipping from there and deleting this folder |
| 46 | + |
| 47 | + # Building the HuggingFaceHub download path: |
| 48 | + hf_path = ( |
| 49 | + "models--" |
| 50 | + + url[0] |
| 51 | + + "--" |
| 52 | + + url[1] |
| 53 | + + "/snapshots/" |
| 54 | + + str(neturls[modelname + "_commit"]) |
| 55 | + + "/" |
| 56 | + + targzfn |
| 57 | + ) |
| 58 | + |
| 59 | + filename = os.path.join(target_dir, hf_path) |
| 60 | + with tarfile.open(filename, mode="r:gz") as tar: |
| 61 | + for member in tar: |
| 62 | + if not member.isdir(): |
| 63 | + fname = Path(member.name).name # getting the filename |
| 64 | + tar.makefile(member, target_dir + "/" + fname) |
| 65 | + # tar.extractall(target_dir, members=tarfilenamecutting(tar)) |
| 66 | + |
| 67 | + if removeHFfolder: |
| 68 | + # Removing folder |
| 69 | + import shutil |
| 70 | + shutil.rmtree( |
| 71 | + Path(os.path.join(target_dir, "models--" + url[0] + "--" + url[1])) |
| 72 | + ) |
| 73 | + |
| 74 | + else: |
| 75 | + models = [fn for fn in neturls.keys()] |
| 76 | + print("Model does not exist: ", modelname) |
| 77 | + print("Pick one of the following: ", MODELOPTIONS) |
| 78 | + |
| 79 | + |
| 80 | +if __name__ == "__main__": |
| 81 | + print("Randomly downloading a model for testing...") |
| 82 | + |
| 83 | + import random |
| 84 | + #modelname = 'full_cat' |
| 85 | + modelname = random.choice(MODELOPTIONS) |
| 86 | + |
| 87 | + target_dir = '/Users/alex/Downloads' # folder has to exist! |
| 88 | + download_hugginface_model(modelname, target_dir) |
0 commit comments