Skip to content

Commit f0fae7d

Browse files
authored
Support fetching models from GitHub. Make cache expire. (#30)
- Unzip the compressed data file in model 244262 - Update cache key to include the month => cache is never more than a month out of date. - Support a github key in the YAML configuration that overwrites the ModelDB copy of a model with the GitHub copy.
1 parent ab447e5 commit f0fae7d

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

.github/workflows/nrn-modeldb-ci.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ jobs:
8383
if: github.event_name == 'pull_request'
8484

8585
- name: Install dependencies and project
86+
id: install-deps
8687
run: |
8788
set
8889
# Set up Xvfb
@@ -93,15 +94,16 @@ jobs:
9394
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
9495
#Install project in editable mode
9596
python -m pip install -e .
97+
echo "::set-output name=date::$(date -u "+%Y%m")"
9698
9799
- name: Cache ModelDB models
98100
id: cache-models
99-
uses: actions/cache@v2
101+
uses: actions/cache@v3
100102
with:
101103
path: |
102104
cache
103105
modeldb/modeldb-meta.yaml
104-
key: models
106+
key: models-${{steps.install-deps.outputs.date}}
105107

106108
- name: Get ModelDB models
107109
if: steps.cache-models.outputs.cache-hit != 'true'

modeldb/modeldb-run.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,10 @@
10641064
223962:
10651065
skip: true
10661066
comment: takes too long, need to see how to reduce time
1067+
244262:
1068+
script:
1069+
- if [[ ! -f Iintra.dat ]]; then unzip Iintra.dat.zip; fi
1070+
- ls -l Iintra.dat
10671071
149174:
10681072
skip: true
10691073
comment: testing in Python + hardcoded NEURON

modeldb/modeldb.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import base64
55
import os
66
import requests
7+
import time
78
from .progressbar import ProgressBar
89
import yaml
910
from .data import Model
1011
from .config import *
1112
import traceback
1213
from pprint import pformat
1314

14-
def download_model(model_id):
15+
def download_model(arg_tuple):
16+
model_id, model_run_info = arg_tuple
1517
try:
1618
model_json = requests.get(MDB_MODEL_DOWNLOAD_URL.format(model_id=model_id)).json()
1719
model = Model(
@@ -31,7 +33,41 @@ def download_model(model_id):
3133
)
3234
with open(model_zip_uri, "wb+") as zipfile:
3335
zipfile.write(base64.standard_b64decode(url["file_content"]))
34-
except Exception as e: # noqa
36+
37+
if "github" in model_run_info:
38+
# This means we should try to replace the version of the model that
39+
# we downloaded from the ModelDB API just above with a version from
40+
# GitHub
41+
github = model_run_info["github"]
42+
if github == "default":
43+
suffix = ""
44+
elif github.startswith("pull/"):
45+
pr_number = int(github[5:])
46+
suffix = "/pull/{}/head".format(pr_number)
47+
else:
48+
raise Exception("Invalid value for github key: {}".format(github))
49+
github_url = "https://api.github.com/repos/ModelDBRepository/{model_id}/zipball{suffix}".format(
50+
model_id=model_id, suffix=suffix
51+
)
52+
# Replace the local file `model_zip_uri` with the zip file we
53+
# downloaded from `github_url`
54+
num_attempts = 3
55+
status_codes = []
56+
for _ in range(num_attempts):
57+
github_response = requests.get(github_url)
58+
status_codes.append(github_response.status_code)
59+
if github_response.status_code == requests.codes.ok:
60+
break
61+
time.sleep(5)
62+
else:
63+
raise Exception(
64+
"Failed to download {} with status codes {}".format(
65+
github_url, status_codes
66+
)
67+
)
68+
with open(model_zip_uri, "wb+") as zipfile:
69+
zipfile.write(github_response.content)
70+
except Exception as e: # noqa
3571
model = e
3672

3773
return model_id, model
@@ -64,7 +100,10 @@ def _download_models(self, model_list=None):
64100
os.mkdir(MODELS_ZIP_DIR)
65101
models = requests.get(MDB_NEURON_MODELS_URL).json() if model_list is None else model_list
66102
pool = multiprocessing.Pool()
67-
processed_models = pool.imap_unordered(download_model, models)
103+
processed_models = pool.imap_unordered(
104+
download_model,
105+
[(model_id, self._run_instr.get(model_id, {})) for model_id in models],
106+
)
68107
download_err = {}
69108
for model_id, model in ProgressBar.iter(processed_models, len(models)):
70109
if not isinstance(model, Exception):

0 commit comments

Comments
 (0)