Skip to content

Commit 409f819

Browse files
authored
Inference API wrapper client (#65)
1 parent a88c772 commit 409f819

File tree

3 files changed

+224
-0
lines changed

3 files changed

+224
-0
lines changed

src/huggingface_hub/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@
3232
from .file_download import cached_download, hf_hub_download, hf_hub_url
3333
from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id
3434
from .hub_mixin import ModelHubMixin
35+
from .inference_api import InferenceApi
3536
from .repository import Repository
3637
from .snapshot_download import snapshot_download
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import logging
2+
from typing import Dict, List, Optional, Union
3+
4+
import requests
5+
6+
from .hf_api import HfApi
7+
8+
9+
logger = logging.getLogger(__name__)


# Base URL of the hosted Inference API.
ENDPOINT = "https://api-inference.huggingface.co"

# Tasks the Inference API can serve, grouped by modality.
_NLP_TASKS = [
    "text-classification",
    "token-classification",
    "table-question-answering",
    "question-answering",
    "zero-shot-classification",
    "translation",
    "summarization",
    "conversational",
    "feature-extraction",
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "sentence-similarity",
]
_AUDIO_TASKS = [
    "text-to-speech",
    "automatic-speech-recognition",
    "audio-to-audio",
    "audio-source-separation",
    "voice-activity-detection",
]
_VISION_TASKS = [
    "image-classification",
    "object-detection",
    "image-segmentation",
]
_OTHER_TASKS = [
    "structured-data-classification",
]

# Flat list of every supported task (same entries, same order as before).
ALL_TASKS = _NLP_TASKS + _AUDIO_TASKS + _VISION_TASKS + _OTHER_TASKS
42+
43+
44+
class InferenceApi:
    """Client to configure requests and make calls to the HuggingFace Inference API.

    Example:

    >>> from huggingface_hub.inference_api import InferenceApi

    >>> # Mask-fill example
    >>> api = InferenceApi("bert-base-uncased")
    >>> api(inputs="The goal of life is [MASK].")
    [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]

    >>> # Question Answering example
    >>> api = InferenceApi("deepset/roberta-base-squad2")
    >>> inputs = {"question":"What's my name?", "context":"My name is Clara and I live in Berkeley."}
    >>> api(inputs)
    {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}

    >>> # Zero-shot example
    >>> api = InferenceApi("typeform/distilbert-base-uncased-mnli")
    >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
    >>> params = {"candidate_labels":["refund", "legal", "faq"]}
    >>> api(inputs, params)
    {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}

    >>> # Overriding configured task
    >>> api = InferenceApi("bert-base-uncased", task="feature-extraction")
    """

    def __init__(
        self,
        repo_id: str,
        task: Optional[str] = None,
        token: Optional[str] = None,
        gpu: bool = False,
    ):
        """Inits headers and API call information.

        Args:
            repo_id (``str``): Id of repository (e.g. `user/bert-base-uncased`).
            task (``str``, `optional`, defaults ``None``): Whether to force a task instead of using task specified in the repository.
            token (:obj:`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not the authentication token.
                You can find the token in https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using `HfApi().whoami(token)`.
            gpu (``bool``, `optional`, defaults ``False``): Whether to use GPU instead of CPU for inference (requires Startup plan at least).

        Raises:
            ``ValueError``: If no task is configured in the repository and none was passed, or if
                the forced ``task`` is not one of :obj:`ALL_TASKS`.

        .. note::
            Setting :obj:`token` is required when you want to use a private model.
        """
        # Options sent with every request; `wait_for_model` blocks until the
        # model is loaded instead of returning a 503.
        self.options = {"wait_for_model": True, "use_gpu": gpu}

        self.headers = {}
        if isinstance(token, str):
            self.headers["Authorization"] = f"Bearer {token}"

        # Configure task: resolve the pipeline tag from the Hub unless the
        # caller forces a (valid) task.
        model_info = HfApi().model_info(repo_id=repo_id, token=token)
        if not model_info.pipeline_tag and not task:
            raise ValueError(
                "Task not specified in the repository. Please add it to the model card using pipeline_tag (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
            )

        if task and task != model_info.pipeline_tag:
            if task not in ALL_TASKS:
                raise ValueError(f"Invalid task {task}. Make sure it's valid.")

            logger.warning(
                "You're using a different task than the one specified in the repository. Be sure to know what you're doing :)"
            )
            self.task = task
        else:
            self.task = model_info.pipeline_tag

        self.api_url = f"{ENDPOINT}/pipeline/{self.task}/{repo_id}"

    def __repr__(self):
        # Redact the Authorization header so that printing/logging the client
        # never leaks the bearer token.
        attrs = dict(self.__dict__)
        headers = attrs.get("headers")
        if isinstance(headers, dict) and "Authorization" in headers:
            attrs["headers"] = {**headers, "Authorization": "Bearer <redacted>"}
        items = (f"{k}='{v}'" for k, v in attrs.items())
        return f"{self.__class__.__name__}({', '.join(items)})"

    def __call__(
        self,
        inputs: Union[str, Dict, List[str], List[List[str]]],
        params: Optional[Dict] = None,
    ):
        """Call the configured pipeline endpoint and return the decoded JSON.

        Args:
            inputs: Task-specific inputs (string, dict, or nested lists of strings).
            params (``dict``, `optional`): Task-specific parameters forwarded as ``parameters``.

        Returns:
            The JSON-decoded API response. On API errors this is a dict
            containing an ``error`` key rather than an exception.
        """
        payload = {
            "inputs": inputs,
            "options": self.options,
        }

        if params:
            payload["parameters"] = params

        # TODO: Decide if we should raise an error instead of
        # returning the json.
        response = requests.post(
            self.api_url, headers=self.headers, json=payload
        ).json()
        return response

tests/test_inference_api.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2020 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import unittest
17+
18+
from huggingface_hub.inference_api import InferenceApi
19+
20+
from .testing_utils import with_production_testing
21+
22+
23+
class InferenceApiTest(unittest.TestCase):
    """Integration tests for :class:`InferenceApi` against the production API."""

    @with_production_testing
    def test_simple_inference(self):
        # Mask-fill returns a list of candidate dicts.
        client = InferenceApi("bert-base-uncased")
        outputs = client("Hi, I think [MASK] is cool")
        self.assertIsInstance(outputs, list)

        first = outputs[0]
        self.assertIsInstance(first, dict)
        self.assertIn("sequence", first)
        self.assertIn("score", first)

    @with_production_testing
    def test_inference_with_params(self):
        # Zero-shot classification uses extra `params`.
        client = InferenceApi("typeform/distilbert-base-uncased-mnli")
        text = "I bought a device but it is not working and I would like to get reimbursed!"
        labels = {"candidate_labels": ["refund", "legal", "faq"]}
        output = client(text, labels)
        self.assertIsInstance(output, dict)
        self.assertIn("sequence", output)
        self.assertIn("scores", output)

    @with_production_testing
    def test_inference_with_dict_inputs(self):
        # Question answering takes a dict input.
        client = InferenceApi("deepset/roberta-base-squad2")
        output = client(
            {
                "question": "What's my name?",
                "context": "My name is Clara and I live in Berkeley.",
            }
        )
        self.assertIsInstance(output, dict)
        self.assertIn("score", output)
        self.assertIn("answer", output)

    @with_production_testing
    def test_inference_overriding_task(self):
        # Forcing a task different from the repo's pipeline tag.
        client = InferenceApi(
            "sentence-transformers/paraphrase-albert-small-v2",
            task="feature-extraction",
        )
        output = client("This is an example again")
        self.assertIsInstance(output, list)

    @with_production_testing
    def test_inference_overriding_invalid_task(self):
        # An unknown task must be rejected at construction time.
        with self.assertRaises(
            ValueError, msg="Invalid task invalid-task. Make sure it's valid."
        ):
            InferenceApi("bert-base-uncased", task="invalid-task")

    @with_production_testing
    def test_inference_missing_input(self):
        # A QA call without `context` yields an error payload, not an exception.
        client = InferenceApi("deepset/roberta-base-squad2")
        output = client({"question": "What's my name?"})
        self.assertIsInstance(output, dict)
        self.assertIn("error", output)
        self.assertIn("warnings", output)
        self.assertGreater(len(output["warnings"]), 0)

0 commit comments

Comments
 (0)