Skip to content

Commit a099273

Browse files
[Inference Providers] check inference provider mapping for HF Inference API (#2948)
* check inference provider mapping for hf inference * custom check for supported tasks for hf inference * nit * comments * re-record cassettes * fix tests * Apply suggestion Co-authored-by: Lucain <[email protected]> * Update src/huggingface_hub/inference/_providers/hf_inference.py --------- Co-authored-by: Lucain <[email protected]>
1 parent 8efa31f commit a099273

File tree

59 files changed

+13772
-8654
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

59 files changed

+13772
-8654
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,7 @@ def image_segmentation(
12861286
[ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
12871287
```
12881288
"""
1289-
provider_helper = get_provider_helper(self.provider, task="audio-classification")
1289+
provider_helper = get_provider_helper(self.provider, task="image-segmentation")
12901290
request_parameters = provider_helper.prepare_request(
12911291
inputs=image,
12921292
parameters={

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1331,7 +1331,7 @@ async def image_segmentation(
13311331
[ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=<PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>), ...]
13321332
```
13331333
"""
1334-
provider_helper = get_provider_helper(self.provider, task="audio-classification")
1334+
provider_helper = get_provider_helper(self.provider, task="image-segmentation")
13351335
request_parameters = provider_helper.prepare_request(
13361336
inputs=image,
13371337
parameters={

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str:
2424
return api_key or get_token() # type: ignore[return-value]
2525

2626
def _prepare_mapped_model(self, model: Optional[str]) -> str:
27-
if model is not None:
27+
if model is not None and model.startswith(("http://", "https://")):
2828
return model
29-
model = _fetch_recommended_models().get(self.task)
30-
if model is None:
29+
model_id = model if model is not None else _fetch_recommended_models().get(self.task)
30+
if model_id is None:
3131
raise ValueError(
3232
f"Task {self.task} has no recommended model for HF Inference. Please specify a model"
3333
" explicitly. Visit https://huggingface.co/tasks for more info."
3434
)
35-
return model
35+
_check_supported_task(model_id, self.task)
36+
return model_id
3637

3738
def _prepare_url(self, api_key: str, mapped_model: str) -> str:
3839
# hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
@@ -120,3 +121,39 @@ def _fetch_recommended_models() -> Dict[str, Optional[str]]:
120121
response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers())
121122
hf_raise_for_status(response)
122123
return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()}
124+
125+
126+
@lru_cache(maxsize=None)
127+
def _check_supported_task(model: str, task: str) -> None:
128+
from huggingface_hub.hf_api import HfApi
129+
130+
model_info = HfApi().model_info(model)
131+
pipeline_tag = model_info.pipeline_tag
132+
tags = model_info.tags or []
133+
is_conversational = "conversational" in tags
134+
if task in ("text-generation", "conversational"):
135+
if pipeline_tag == "text-generation":
136+
# text-generation + conversational tag -> both tasks allowed
137+
if is_conversational:
138+
return
139+
# text-generation without conversational tag -> only text-generation allowed
140+
if task == "text-generation":
141+
return
142+
raise ValueError(f"Model '{model}' doesn't support task '{task}'.")
143+
144+
if pipeline_tag == "text2text-generation":
145+
if task == "text-generation":
146+
return
147+
raise ValueError(f"Model '{model}' doesn't support task '{task}'.")
148+
149+
if pipeline_tag == "image-text-to-text":
150+
if is_conversational and task == "conversational":
151+
return # Only conversational allowed if tagged as conversational
152+
raise ValueError("Non-conversational image-text-to-text task is not supported.")
153+
154+
# For all other tasks, just check pipeline tag
155+
if pipeline_tag != task:
156+
raise ValueError(
157+
f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'"
158+
)
159+
return

tests/cassettes/TestInferenceClient.test_audio_classification[hf-inference,audio-classification].yaml

Lines changed: 1722 additions & 37 deletions
Large diffs are not rendered by default.

tests/cassettes/TestInferenceClient.test_audio_to_audio[hf-inference,audio-to-audio].yaml

Lines changed: 1724 additions & 3355 deletions
Large diffs are not rendered by default.

tests/cassettes/TestInferenceClient.test_automatic_speech_recognition[hf-inference,automatic-speech-recognition].yaml

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,74 @@
11
interactions:
2+
- request:
3+
body: null
4+
headers:
5+
Accept:
6+
- '*/*'
7+
Accept-Encoding:
8+
- gzip, deflate
9+
Connection:
10+
- keep-alive
11+
X-Amzn-Trace-Id:
12+
- bb0168cc-7f1c-4d14-b411-41f64f5cc7b1
13+
method: GET
14+
uri: https://huggingface.co/api/models/jonatasgrosman/wav2vec2-large-xlsr-53-english
15+
response:
16+
body:
17+
string: '{"_id":"621ffdc136468d709f17cdb3","id":"jonatasgrosman/wav2vec2-large-xlsr-53-english","private":false,"pipeline_tag":"automatic-speech-recognition","library_name":"transformers","tags":["transformers","pytorch","jax","safetensors","wav2vec2","automatic-speech-recognition","audio","en","hf-asr-leaderboard","mozilla-foundation/common_voice_6_0","robust-speech-event","speech","xlsr-fine-tuning-week","dataset:common_voice","dataset:mozilla-foundation/common_voice_6_0","doi:10.57967/hf/3569","license:apache-2.0","model-index","endpoints_compatible","region:us"],"downloads":18234759,"likes":466,"modelId":"jonatasgrosman/wav2vec2-large-xlsr-53-english","author":"jonatasgrosman","sha":"569a6236e92bd5f7652a0420bfe9bb94c5664080","lastModified":"2023-03-25T10:56:55.000Z","gated":false,"disabled":false,"model-index":[{"name":"XLSR
18+
Wav2Vec2 English by Jonatas Grosman","results":[{"task":{"name":"Automatic
19+
Speech Recognition","type":"automatic-speech-recognition"},"dataset":{"name":"Common
20+
Voice en","type":"common_voice","args":"en"},"metrics":[{"name":"Test WER","type":"wer","value":19.06,"verified":false},{"name":"Test
21+
CER","type":"cer","value":7.69,"verified":false},{"name":"Test WER (+LM)","type":"wer","value":14.81,"verified":false},{"name":"Test
22+
CER (+LM)","type":"cer","value":6.84,"verified":false}]},{"task":{"name":"Automatic
23+
Speech Recognition","type":"automatic-speech-recognition"},"dataset":{"name":"Robust
24+
Speech Event - Dev Data","type":"speech-recognition-community-v2/dev_data","args":"en"},"metrics":[{"name":"Dev
25+
WER","type":"wer","value":27.72,"verified":false},{"name":"Dev CER","type":"cer","value":11.65,"verified":false},{"name":"Dev
26+
WER (+LM)","type":"wer","value":20.85,"verified":false},{"name":"Dev CER (+LM)","type":"cer","value":11.01,"verified":false}]}]}],"config":{"architectures":["Wav2Vec2ForCTC"],"model_type":"wav2vec2"},"cardData":{"language":"en","datasets":["common_voice","mozilla-foundation/common_voice_6_0"],"metrics":["wer","cer"],"tags":["audio","automatic-speech-recognition","en","hf-asr-leaderboard","mozilla-foundation/common_voice_6_0","robust-speech-event","speech","xlsr-fine-tuning-week"],"license":"apache-2.0","model-index":[{"name":"XLSR
27+
Wav2Vec2 English by Jonatas Grosman","results":[{"task":{"name":"Automatic
28+
Speech Recognition","type":"automatic-speech-recognition"},"dataset":{"name":"Common
29+
Voice en","type":"common_voice","args":"en"},"metrics":[{"name":"Test WER","type":"wer","value":19.06,"verified":false},{"name":"Test
30+
CER","type":"cer","value":7.69,"verified":false},{"name":"Test WER (+LM)","type":"wer","value":14.81,"verified":false},{"name":"Test
31+
CER (+LM)","type":"cer","value":6.84,"verified":false}]},{"task":{"name":"Automatic
32+
Speech Recognition","type":"automatic-speech-recognition"},"dataset":{"name":"Robust
33+
Speech Event - Dev Data","type":"speech-recognition-community-v2/dev_data","args":"en"},"metrics":[{"name":"Dev
34+
WER","type":"wer","value":27.72,"verified":false},{"name":"Dev CER","type":"cer","value":11.65,"verified":false},{"name":"Dev
35+
WER (+LM)","type":"wer","value":20.85,"verified":false},{"name":"Dev CER (+LM)","type":"cer","value":11.01,"verified":false}]}]}]},"transformersInfo":{"auto_model":"AutoModelForCTC","pipeline_tag":"automatic-speech-recognition","processor":"AutoProcessor"},"siblings":[{"rfilename":".gitattributes"},{"rfilename":"README.md"},{"rfilename":"alphabet.json"},{"rfilename":"config.json"},{"rfilename":"eval.py"},{"rfilename":"flax_model.msgpack"},{"rfilename":"full_eval.sh"},{"rfilename":"language_model/attrs.json"},{"rfilename":"language_model/lm.binary"},{"rfilename":"language_model/unigrams.txt"},{"rfilename":"log_mozilla-foundation_common_voice_6_0_en_test_predictions.txt"},{"rfilename":"log_mozilla-foundation_common_voice_6_0_en_test_predictions_greedy.txt"},{"rfilename":"log_mozilla-foundation_common_voice_6_0_en_test_targets.txt"},{"rfilename":"log_speech-recognition-community-v2_dev_data_en_validation_predictions.txt"},{"rfilename":"log_speech-recognition-community-v2_dev_data_en_validation_predictions_greedy.txt"},{"rfilename":"log_speech-recognition-community-v2_dev_data_en_validation_targets.txt"},{"rfilename":"model.safetensors"},{"rfilename":"mozilla-foundation_common_voice_6_0_en_test_eval_results.txt"},{"rfilename":"mozilla-foundation_common_voice_6_0_en_test_eval_results_greedy.txt"},{"rfilename":"preprocessor_config.json"},{"rfilename":"pytorch_model.bin"},{"rfilename":"special_tokens_map.json"},{"rfilename":"speech-recognition-community-v2_dev_data_en_validation_eval_results.txt"},{"rfilename":"speech-recognition-community-v2_dev_data_en_validation_eval_results_greedy.txt"},{"rfilename":"vocab.json"}],"spaces":["bertin-project/bertin-gpt-j-6B","Gradio-Blocks/Alexa-NLU-Clone","qanastek/Alexa-NLU-Clone","Gradio-Blocks/poor-mans-duplex","awacke1/ASR-High-Accuracy-Test","trysem/Spleeter_and_ASR","Detomo/audio-stream-translate","yashsrivastava/speech-to-text-yash","RealTimeLiveAIForHealth/ASR-High-Accuracy-Test","manmeetkaurbaxi/YouTube-Video-Summarizer","baaas
tien/Spleeter_and_ASR","GeekedReals/jonatasgrosman-wav2vec2-large-xlsr-53-english","Charles95/gradio-tasks","Gna1L/jonatasgrosman-wav2vec2-large-xlsr-53-english","Mintiny/Customer_Review_Audio_Analysis","awacke1/STT-TTS-ASR-AI-NLP-Pipeline","Detomo/audio-translate","Amrrs/yt-video-summarizer","JPLTedCas/TedCasSpeechRecognition","stanciu/jonatasgrosman-wav2vec2-large-xlsr-53-english","raunak627887/jonatasgrosman-wav2vec2-large-xlsr-53-english","Hrsh-Venket/Corrected-Speech-to-Text","Rhyolite/jonatasgrosman-wav2vec2-large-xlsr-53-english","Grepper/jonatasgrosman-wav2vec2-large-xlsr-53-english","Yarumo/jonatasgrosman-wav2vec2-large-xlsr-53-english","jbraun19/ASR-High-Accuracy-Test","sankalphimself/pitchpal","Rajab123/jonatasgrosman-wav2vec2-large-xlsr-53-english","SteeleN1/jonatasgrosman-wav2vec2-large-xlsr-53-english","melazab1/jonatasgrosman-wav2vec2-large-xlsr-53-english","amaamas/jonatasgrosman-wav2vec2-large-xlsr-53-english","Gearijigu/jonatasgrosman-wav2vec2-large-xlsr-53-english","sebasjm/jonatasgrosman-wav2vec2-large-xlsr-53-english","codetopolymath/jonatasgrosman-wav2vec2-large-xlsr-53-english","lingdai/jonatasgrosman-wav2vec2-large-xlsr-53-english","nabdtran/jonatasgrosman-wav2vec2-large-xlsr-53-english","NourAlmolhem/jonatasgrosman-wav2vec2-large-xlsr-53-english","leetik/jonatasgrosman-wav2vec2-large-xlsr-53-english","mastere00/jonatasgrosman-wav2vec2-large-xlsr-53-english","pragyachik/jonatasgrosman-wav2vec2-large-xlsr-53-english","shubhsnow/jonatasgrosman-wav2vec2-large-xlsr-53-english","kushiel/jonatasgrosman-wav2vec2-large-xlsr-53-english","adarsh8986/jonatasgrosman-wav2vec2-large-xlsr-53-english","Baghdad99/eng-to-hau","Akkaris/jonatasgrosman-wav2vec2-large-xlsr-53-english","Baghdad99/english-to-hausa","anonymous4me/jonatasgrosman-wav2vec2-large-xlsr-53-english","LEWOPO/Voice_to_Text","Nikhil0987/speechrecho","neridonk/jonatasgrosman-wav2vec2-large-xlsr-53-english","ganeshkamath89/gradio-huggingface-pipeline-tasks-demo-all","quangnhan145/jonatasgrosman-
wav2vec2-large-xlsr-53-english-demo-app","niveone/jonatasgrosman-wav2vec2-large-xlsr-53-english","dincali/jonatasgrosman-wav2vec2-large-xlsr-53-english","oryxsoftware/speech-to-text","Shashwat2528/Avishkaarak-ekta-new-audio","AVISHKAARAM/avishkarak-ekta-audio","IES-Rafael-Alberti/AudioToText","Mahmoud2020220/jonatasgrosman-wav2vec2-large-xlsr-53-english","saronium/jonatasgrosman-wav2vec2-large-xlsr-53-english","trizzzy/jonatasgrosman-wav2vec2-large-xlsr-53-english","guna-entrans/jonatasgrosman-wav2vec2-large-xlsr-53-english","pmiguelpds/jonatasgrosman-wav2vec2-large-xlsr-53-english","LordCoffee/transcript","manohar025/video-summarizer","maan2605/Youtube_Video_Summarizer_using_ASR","ABIDFAYAZ/meeting-transcription","derrideanlils/jonatasgrosman-wav2vec2-large-xlsr-53-english","saud-altuwaijri/demo4","Durganihantri/AI-Child-Behavior-Assessment","65rted6tfdjhgfjyrf/Gibberish-transcribr","aicodingfun/Alexa-NLU-Clone"],"createdAt":"2022-03-02T23:29:05.000Z","safetensors":{"parameters":{"F32":315472545},"total":315472545},"inference":"warm","usedStorage":7177619310}'
36+
headers:
37+
Access-Control-Allow-Origin:
38+
- https://huggingface.co
39+
Access-Control-Expose-Headers:
40+
- X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash
41+
Connection:
42+
- keep-alive
43+
Content-Length:
44+
- '8113'
45+
Content-Type:
46+
- application/json; charset=utf-8
47+
Date:
48+
- Fri, 21 Mar 2025 11:29:30 GMT
49+
ETag:
50+
- W/"1fb1-kJSNr/PIoqRzIeGGFb80dxk+xKs"
51+
Referrer-Policy:
52+
- strict-origin-when-cross-origin
53+
Vary:
54+
- Origin
55+
Via:
56+
- 1.1 02ee9ebd8a83522edf11335f04975776.cloudfront.net (CloudFront)
57+
X-Amz-Cf-Id:
58+
- lUVe7zu2SXod2amCAy6NI4WliBYareXEaVPdf4NCasTTA5C5P_dj-Q==
59+
X-Amz-Cf-Pop:
60+
- CDG52-P4
61+
X-Cache:
62+
- Miss from cloudfront
63+
X-Powered-By:
64+
- huggingface-moon
65+
X-Request-Id:
66+
- Root=1-67dd4d9a-2b095a0137b6d23f0720dd7b;bb0168cc-7f1c-4d14-b411-41f64f5cc7b1
67+
cross-origin-opener-policy:
68+
- same-origin
69+
status:
70+
code: 200
71+
message: OK
272
- request:
373
body: !!binary |
474
ZkxhQwAAACIQABAAAAGOABkSA+gA8AABIqDeyGtH35+dVPygWCFAFS5GAwAAEgAAAAAAAAAAAAAA
@@ -1666,7 +1736,7 @@ interactions:
16661736
Content-Length:
16671737
- '94321'
16681738
X-Amzn-Trace-Id:
1669-
- 3d462b35-26a7-4c68-959b-f6a7d0301caf
1739+
- d3d22a44-b253-43a2-8f4f-6a20b5aee642
16701740
method: POST
16711741
uri: https://router.huggingface.co/hf-inference/models/jonatasgrosman/wav2vec2-large-xlsr-53-english
16721742
response:
@@ -1675,40 +1745,48 @@ interactions:
16751745
headers:
16761746
Access-Control-Allow-Origin:
16771747
- '*'
1678-
Access-Control-Expose-Headers:
1679-
- X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Xet-Access-Token,X-Xet-Token-Expiration,X-Xet-Refresh-Route,X-Xet-Cas-Url,X-Xet-Hash
16801748
Connection:
16811749
- keep-alive
1682-
Content-Length:
1683-
- '49'
16841750
Content-Type:
16851751
- application/json
16861752
Date:
1687-
- Thu, 13 Feb 2025 15:49:08 GMT
1753+
- Fri, 21 Mar 2025 11:29:31 GMT
16881754
Referrer-Policy:
16891755
- strict-origin-when-cross-origin
1756+
Transfer-Encoding:
1757+
- chunked
16901758
Via:
1691-
- 1.1 fba88aca5f6c32c13257a59071f94248.cloudfront.net (CloudFront)
1759+
- 1.1 6676a739f016238678e391e91007cc98.cloudfront.net (CloudFront)
16921760
X-Amz-Cf-Id:
1693-
- Ae6o7enRbofgvT3JVME57ywv8zStyC4fsspyhh_4QrpEvLi6ZYwLWA==
1761+
- hu8tduAfzg-IW6ZBolY67qPKbBnahkp3CYwuCBoJkR3u3Wn2BHygvQ==
16941762
X-Amz-Cf-Pop:
16951763
- CDG55-P3
16961764
X-Cache:
16971765
- Miss from cloudfront
16981766
X-Powered-By:
16991767
- huggingface-moon
1768+
X-Robots-Tag:
1769+
- none
17001770
access-control-allow-credentials:
17011771
- 'true'
1772+
access-control-expose-headers:
1773+
- x-compute-type, x-compute-time
17021774
cross-origin-opener-policy:
17031775
- same-origin
1776+
server:
1777+
- uvicorn
17041778
vary:
17051779
- Origin, Access-Control-Request-Method, Access-Control-Request-Headers
1780+
x-compute-audio-length:
1781+
- '4.65'
17061782
x-compute-time:
1707-
- '0.638'
1783+
- '0.712'
17081784
x-compute-type:
1709-
- cache
1785+
- cpu
1786+
x-inference-id:
1787+
- EUm9ZEn7NXWqjizXKR6bM
17101788
x-request-id:
1711-
- Root=1-67ae1474-43a45db64b62db6251089c0d;3d462b35-26a7-4c68-959b-f6a7d0301caf
1789+
- Root=1-67dd4d9a-28cb8d3e43da94102f4cd965;d3d22a44-b253-43a2-8f4f-6a20b5aee642
17121790
x-sha:
17131791
- 569a6236e92bd5f7652a0420bfe9bb94c5664080
17141792
status:

0 commit comments

Comments (0)