Skip to content

Commit 1ec86b3

Browse files
authored
Upgrade AllenNLP + hotfix (#287)
* Fix AllenNLP inference * Fix API GA workflow * Fix dataset name of batch test * Fix the docker GA workflow * Ensure AllenNLP test is actually running and not just skipped * Change env variable * Fix style, upgrade allennlp, revert some changes * Fix dependencies of AllenNLP * Test longer retry * Format * Test again
1 parent a06f59a commit 1ec86b3

File tree

14 files changed

+55
-345
lines changed

14 files changed

+55
-345
lines changed

.github/workflows/python-api-allennlp.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ on:
55
paths:
66
- "api-inference-community/docker_images/allennlp/**"
77
jobs:
8-
docker:
8+
build:
99
runs-on: ubuntu-latest
1010
steps:
1111
- name: Set up Python ${{ matrix.python-version }}
1212
uses: actions/setup-python@v2
1313
with:
1414
python-version: "3.8"
15+
- name: Checkout
16+
uses: actions/checkout@v2
1517
- name: Set up QEMU
1618
uses: docker/setup-qemu-action@v1
1719
- name: Set up Docker Buildx
@@ -22,5 +24,5 @@ jobs:
2224
pip install --upgrade pip
2325
pip install pytest pillow httpx
2426
pip install -e .
25-
- run: RUN_DOCKER=1 pytest -sv tests/test_dockers.py::DockerImageTests::test_allennlp
27+
- run: RUN_DOCKER_TESTS=1 pytest -sv tests/test_dockers.py::DockerImageTests::test_allennlp
2628
working-directory: api-inference-community

.github/workflows/python-api-tests.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ jobs:
1212

1313
steps:
1414
- run: |
15+
sudo apt-get update
1516
sudo apt-get install ffmpeg
1617
1718
- uses: actions/checkout@v2

api-inference-community/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.PHONY: quality style
22

33

4-
check_dirs := api_inference_community tests
4+
check_dirs := api_inference_community tests docker_images
55

66

77

api-inference-community/docker_images/allennlp/app/main.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1+
import functools
12
import logging
23
import os
34
from typing import Dict, Type
45

6+
from api_inference_community.routes import pipeline_route, status_ok
57
from app.pipelines import Pipeline, QuestionAnsweringPipeline
6-
from app.routes import pipeline_route, status_ok
78
from starlette.applications import Starlette
9+
from starlette.middleware import Middleware
10+
from starlette.middleware.gzip import GZipMiddleware
811
from starlette.routing import Route
912

1013

14+
TASK = os.getenv("TASK")
15+
MODEL_ID = os.getenv("MODEL_ID")
16+
17+
1118
logger = logging.getLogger(__name__)
1219

1320

@@ -30,7 +37,10 @@
3037
}
3138

3239

33-
def get_pipeline(task: str, model_id: str) -> Pipeline:
40+
@functools.lru_cache()
41+
def get_pipeline() -> Pipeline:
42+
task = os.environ["TASK"]
43+
model_id = os.environ["MODEL_ID"]
3444
if task not in ALLOWED_TASKS:
3545
raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
3646
return ALLOWED_TASKS[task](model_id)
@@ -41,14 +51,21 @@ def get_pipeline(task: str, model_id: str) -> Pipeline:
4151
Route("/{whatever:path}", pipeline_route, methods=["POST"]),
4252
]
4353

44-
app = Starlette(routes=routes)
54+
middleware = [Middleware(GZipMiddleware, minimum_size=1000)]
4555
if os.environ.get("DEBUG", "") == "1":
4656
from starlette.middleware.cors import CORSMiddleware
4757

48-
app.add_middleware(
49-
CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"]
58+
middleware.append(
59+
Middleware(
60+
CORSMiddleware,
61+
allow_origins=["*"],
62+
allow_headers=["*"],
63+
allow_methods=["*"],
64+
)
5065
)
5166

67+
app = Starlette(routes=routes, middleware=middleware)
68+
5269

5370
@app.on_event("startup")
5471
async def startup_event():
@@ -57,14 +74,18 @@ async def startup_event():
5774
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
5875
logger.handlers = [handler]
5976

60-
task = os.environ["TASK"]
61-
model_id = os.environ["MODEL_ID"]
62-
63-
app.pipeline = get_pipeline(task, model_id)
77+
# Link between `api-inference-community` and framework code.
78+
app.get_pipeline = get_pipeline
79+
try:
80+
get_pipeline()
81+
except Exception:
82+
# We can fail so we can show exception later.
83+
pass
6484

6585

6686
if __name__ == "__main__":
67-
task = os.environ["TASK"]
68-
model_id = os.environ["MODEL_ID"]
69-
70-
get_pipeline(task, model_id)
87+
try:
88+
get_pipeline()
89+
except Exception:
90+
# We can fail so we can show exception later.
91+
pass

api-inference-community/docker_images/allennlp/app/pipelines/question_answering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(
1111
self,
1212
model_id: str,
1313
):
14-
self.predictor = Predictor.from_path(model_id)
14+
self.predictor = Predictor.from_path("hf://" + model_id)
1515

1616
def __call__(self, inputs: Dict[str, str]) -> Dict[str, Any]:
1717
"""

api-inference-community/docker_images/allennlp/app/routes.py

Lines changed: 0 additions & 112 deletions
This file was deleted.

0 commit comments

Comments
 (0)