Skip to content

Commit dea8c4a

Browse files
authored
Generic pipeline is all you need (#371)
1 parent 8693b8e commit dea8c4a

15 files changed

+138
-77
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# TODO: Merge this with allenNLP to have a single workflow for all docker images.
2+
name: generic-docker
3+
4+
on:
5+
pull_request:
6+
paths:
7+
- "api-inference-community/docker_images/generic/**"
8+
jobs:
9+
build:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- name: Set up Python ${{ matrix.python-version }}
13+
uses: actions/setup-python@v2
14+
with:
15+
python-version: "3.8"
16+
- name: Checkout
17+
uses: actions/checkout@v2
18+
- name: Set up QEMU
19+
uses: docker/setup-qemu-action@v1
20+
- name: Set up Docker Buildx
21+
uses: docker/setup-buildx-action@v1
22+
- name: Install dependencies
23+
working-directory: api-inference-community
24+
run: |
25+
pip install --upgrade pip
26+
pip install pytest pillow httpx
27+
pip install -e .
28+
- run: RUN_DOCKER_TESTS=1 pytest -sv tests/test_dockers.py::DockerImageTests::test_generic
29+
working-directory: api-inference-community

api-inference-community/docker_images/generic/app/main.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import functools
22
import logging
33
import os
4-
from typing import Dict, Type
54

65
from api_inference_community.routes import pipeline_route, status_ok
76
from app.pipelines import Pipeline
@@ -18,24 +17,10 @@
1817
logger = logging.getLogger(__name__)
1918

2019

21-
ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {
22-
"audio-to-audio": Pipeline,
23-
"automatic-speech-recognition": Pipeline,
24-
"feature-extraction": Pipeline,
25-
"image-classification": Pipeline,
26-
"structured-data-classification": Pipeline,
27-
"text-to-image": Pipeline,
28-
"token-classification": Pipeline,
29-
}
30-
31-
3220
@functools.lru_cache()
3321
def get_pipeline() -> Pipeline:
34-
task = os.environ["TASK"]
3522
model_id = os.environ["MODEL_ID"]
36-
if task not in ALLOWED_TASKS:
37-
raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
38-
return ALLOWED_TASKS[task](model_id)
23+
return Pipeline(model_id)
3924

4025

4126
routes = [
Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import os
21
from typing import Dict, List
3-
from unittest import TestCase, skipIf
4-
5-
from app.main import ALLOWED_TASKS, get_pipeline
62

73

84
# Must contain at least one example of each implemented pipeline
@@ -14,39 +10,8 @@
1410
# This is very slow the first time as fasttext model is large.
1511
"feature-extraction": ["osanseviero/fasttext_english"],
1612
"image-classification": ["osanseviero/fastai_cat_vs_dog"],
13+
"structured-data-classification": ["osanseviero/wine-quality"],
14+
"text-classification": ["osanseviero/fasttext_nearest"],
1715
"text-to-image": ["osanseviero/BigGAN-deep-128"],
1816
"token-classification": ["osanseviero/en_core_web_sm"],
19-
"structured-data-classification": ["osanseviero/wine-quality"],
20-
}
21-
22-
ALL_TASKS = {
23-
"audio-to-audio",
24-
"automatic-speech-recognition",
25-
"feature-extraction",
26-
"image-classification",
27-
"question-answering",
28-
"sentence-similarity",
29-
"structure-data-classification",
30-
"text-to-speech",
31-
"token-classification",
3217
}
33-
34-
35-
class PipelineTestCase(TestCase):
36-
@skipIf(
37-
os.path.dirname(os.path.dirname(__file__)).endswith("common"),
38-
"common is a special case",
39-
)
40-
def test_has_at_least_one_task_enabled(self):
41-
self.assertGreater(
42-
len(ALLOWED_TASKS.keys()), 0, "You need to implement at least one task"
43-
)
44-
45-
def test_unsupported_tasks(self):
46-
unsupported_tasks = ALL_TASKS - ALLOWED_TASKS.keys()
47-
for unsupported_task in unsupported_tasks:
48-
with self.subTest(msg=unsupported_task, task=unsupported_task):
49-
os.environ["TASK"] = unsupported_task
50-
os.environ["MODEL_ID"] = "XX"
51-
with self.assertRaises(EnvironmentError):
52-
get_pipeline()

api-inference-community/docker_images/generic/tests/test_api_audio_to_audio.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
from unittest import TestCase, skipIf
55

66
from api_inference_community.validation import ffmpeg_read
7-
from app.main import ALLOWED_TASKS
87
from parameterized import parameterized_class
98
from starlette.testclient import TestClient
109
from tests.test_api import TESTABLE_MODELS
1110

1211

1312
@skipIf(
14-
"audio-to-audio" not in ALLOWED_TASKS,
13+
"audio-to-audio" not in TESTABLE_MODELS,
1514
"audio-to-audio not implemented",
1615
)
1716
@parameterized_class(

api-inference-community/docker_images/generic/tests/test_api_automatic_speech_recognition.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
from unittest import TestCase, skipIf
44

5-
from app.main import ALLOWED_TASKS
65
from parameterized import parameterized_class
76
from starlette.testclient import TestClient
87
from tests.test_api import TESTABLE_MODELS
98

109

1110
@skipIf(
12-
"automatic-speech-recognition" not in ALLOWED_TASKS,
11+
"automatic-speech-recognition" not in TESTABLE_MODELS,
1312
"automatic-speech-recognition not implemented",
1413
)
1514
@parameterized_class(

api-inference-community/docker_images/generic/tests/test_api_feature_extraction.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
from unittest import TestCase, skipIf
44

5-
from app.main import ALLOWED_TASKS
65
from parameterized import parameterized_class
76
from starlette.testclient import TestClient
87
from tests.test_api import TESTABLE_MODELS
98

109

1110
@skipIf(
12-
"feature-extraction" not in ALLOWED_TASKS,
11+
"feature-extraction" not in TESTABLE_MODELS,
1312
"feature-extraction not implemented",
1413
)
1514
@parameterized_class(

api-inference-community/docker_images/generic/tests/test_api_image_classification.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
from unittest import TestCase, skipIf
44

5-
from app.main import ALLOWED_TASKS
65
from parameterized import parameterized_class
76
from starlette.testclient import TestClient
87
from tests.test_api import TESTABLE_MODELS
98

109

1110
@skipIf(
12-
"image-classification" not in ALLOWED_TASKS,
11+
"image-classification" not in TESTABLE_MODELS,
1312
"image-classification not implemented",
1413
)
1514
@parameterized_class(

api-inference-community/docker_images/generic/tests/test_api_question_answering.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
from unittest import TestCase, skipIf
44

5-
from app.main import ALLOWED_TASKS
65
from starlette.testclient import TestClient
76
from tests.test_api import TESTABLE_MODELS
87

98

109
@skipIf(
11-
"question-answering" not in ALLOWED_TASKS,
12-
"question-answering not implemented",
10+
"text-to-speech" not in TESTABLE_MODELS,
11+
"text-to-speech not implemented",
1312
)
1413
class QuestionAnsweringTestCase(TestCase):
1514
def setUp(self):

api-inference-community/docker_images/generic/tests/test_api_sentence_similarity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
import os
33
from unittest import TestCase, skipIf
44

5-
from app.main import ALLOWED_TASKS
65
from starlette.testclient import TestClient
76
from tests.test_api import TESTABLE_MODELS
87

98

109
@skipIf(
11-
"sentence-similarity" not in ALLOWED_TASKS,
10+
"sentence-similarity" not in TESTABLE_MODELS,
1211
"sentence-similarity not implemented",
1312
)
1413
class SentenceSimilarityTestCase(TestCase):

api-inference-community/docker_images/generic/tests/test_api_speech_segmentation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import os
33
from unittest import TestCase, skipIf
44

5-
from app.main import ALLOWED_TASKS
65
from starlette.testclient import TestClient
76
from tests.test_api import TESTABLE_MODELS
87

98

109
@skipIf(
11-
"speech-segmentation" not in ALLOWED_TASKS,
12-
"speech-segmentation not implemented",
10+
"text-to-speech" not in TESTABLE_MODELS,
11+
"text-to-speech not implemented",
1312
)
1413
class SpeechSegmentationTestCase(TestCase):
1514
def setUp(self):

0 commit comments

Comments
 (0)