Skip to content

Commit b9c7c08

Browse files
committed
Only add text-generation and conversational task from feedback
1 parent 3490a10 commit b9c7c08

File tree

5 files changed

+69
-152
lines changed

5 files changed

+69
-152
lines changed

docs/source/en/guides/inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
192192

193193
[`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:
194194

195-
| Task | Black Forest Labs | Cerebras | Clarifai | Cohere | fal-ai | Featherless AI | Fireworks AI | Groq | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Nscale | OVHcloud | Public AI | Replicate | Sambanova | Scaleway | Together | Wavespeed | Zai |
195+
| Task | Black Forest Labs | Cerebras | Clarifai | Cohere | fal-ai | Featherless AI | Fireworks AI | Groq | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Nscale | OVHcloud AI Endpoints | Public AI | Replicate | Sambanova | Scaleway | Together | Wavespeed | Zai |
196196
| --------------------------------------------------- | ----------------- | -------- | -------- | ------ | ------ | -------------- | ------------ | ---- | ------------ | ---------- | ---------------- | --------- | ------ | -------- | ---------- | --------- | --------- | --------- | -------- | --------- | ---- |
197197
| [`~InferenceClient.audio_classification`] ||||||||||||||||||||||
198198
| [`~InferenceClient.audio_to_audio`] ||||||||||||||||||||||

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
3939
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
4040
from .openai import OpenAIConversationalTask
41-
from .ovhcloud import OVHcloudAIEndpointsAutomaticSpeechRecognitionTask, OVHcloudAIEndpointsConversationalTask, OVHcloudAIEndpointsFeatureExtractionTask, OVHcloudAIEndpointsTextToImageTask
41+
from .ovhcloud import OVHcloudAIEndpointsConversationalTask, OVHcloudAIEndpointsTextGenerationTask
4242
from .publicai import PublicAIConversationalTask
4343
from .replicate import (
4444
ReplicateAutomaticSpeechRecognitionTask,
@@ -176,9 +176,7 @@
176176
},
177177
"ovhcloud": {
178178
"conversational": OVHcloudAIEndpointsConversationalTask(),
179-
"text-to-image": OVHcloudAIEndpointsTextToImageTask(),
180-
"feature-extraction": OVHcloudAIEndpointsFeatureExtractionTask(),
181-
"automatic-speech-recognition": OVHcloudAIEndpointsAutomaticSpeechRecognitionTask(),
179+
"text-generation": OVHcloudAIEndpointsTextGenerationTask(),
182180
},
183181
"publicai": {
184182
"conversational": PublicAIConversationalTask(),
Lines changed: 19 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,34 @@
1-
import base64
2-
from abc import ABC
3-
from typing import Any, Dict, Optional, Union
1+
from typing import Any, Optional, Union
42

5-
from huggingface_hub.hf_api import InferenceProviderMapping
63
from huggingface_hub.inference._common import RequestParameters, _as_dict
7-
from huggingface_hub.inference._providers._common import (
8-
TaskProviderHelper,
9-
filter_none,
10-
)
4+
from huggingface_hub.inference._providers._common import BaseConversationalTask, BaseTextGenerationTask
5+
116

127
_PROVIDER = "ovhcloud"
138
_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net"
149

15-
class OVHcloudAIEndpointsTask(TaskProviderHelper, ABC):
16-
def __init__(self, task: str):
17-
super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)
18-
19-
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
20-
if self.task == "text-to-image":
21-
return "/v1/images/generations"
22-
elif self.task == "conversational":
23-
return "/v1/chat/completions"
24-
elif self.task == "feature-extraction":
25-
return "/v1/embeddings"
26-
elif self.task == "automatic-speech-recognition":
27-
return "/v1/audio/transcriptions"
28-
raise ValueError(f"Unsupported task '{self.task}' for OVHcloud AI Endpoints.")
29-
30-
def _prepare_payload_as_dict(
31-
self, messages: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
32-
) -> Optional[Dict]:
33-
return {"messages": messages, "model": provider_mapping_info.provider_id, **filter_none(parameters)}
34-
35-
36-
class OVHcloudAIEndpointsConversationalTask(OVHcloudAIEndpointsTask):
37-
def __init__(self):
38-
super().__init__("conversational")
39-
40-
def _prepare_payload_as_dict(
41-
self, messages: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
42-
) -> Optional[dict]:
43-
return super()._prepare_payload_as_dict(messages, parameters, provider_mapping_info)
44-
4510

46-
class OVHcloudAIEndpointsTextToImageTask(OVHcloudAIEndpointsTask):
11+
class OVHcloudAIEndpointsConversationalTask(BaseConversationalTask):
4712
def __init__(self):
48-
super().__init__("text-to-image")
13+
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
4914

50-
def _prepare_payload_as_dict(
51-
self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
52-
) -> Optional[dict]:
53-
mapped_model = provider_mapping_info.provider_id
54-
return {"prompt": inputs, "model": mapped_model, **filter_none(parameters)}
15+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
16+
return "/v1/chat/completions"
5517

56-
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
57-
response_dict = _as_dict(response)
58-
return base64.b64decode(response_dict["data"][0]["b64_json"])
59-
60-
class OVHcloudAIEndpointsFeatureExtractionTask(OVHcloudAIEndpointsTask):
61-
def __init__(self):
62-
super().__init__("feature-extraction")
6318

64-
def _prepare_payload_as_dict(
65-
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
66-
) -> Optional[Dict]:
67-
return {"input": inputs, "model": provider_mapping_info.provider_id, **filter_none(parameters)}
68-
69-
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
70-
embeddings = _as_dict(response)["data"]
71-
return [embedding["embedding"] for embedding in embeddings]
72-
73-
class OVHcloudAIEndpointsAutomaticSpeechRecognitionTask(OVHcloudAIEndpointsTask):
19+
class OVHcloudAIEndpointsTextGenerationTask(BaseTextGenerationTask):
7420
def __init__(self):
75-
super().__init__("automatic-speech-recognition")
21+
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
7622

77-
def _prepare_payload_as_dict(
78-
self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
79-
) -> Optional[dict]:
80-
return {"file": inputs, "model": provider_mapping_info.provider_id, **filter_none(parameters)}
23+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
24+
return "/v1/chat/completions"
8125

8226
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
83-
response_dict = _as_dict(response)
84-
return response_dict["text"]
27+
output = _as_dict(response)["choices"][0]
28+
return {
29+
"generated_text": output["text"],
30+
"details": {
31+
"finish_reason": output.get("finish_reason"),
32+
"seed": output.get("seed"),
33+
},
34+
}

tests/test_inference_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,8 @@
118118
"conversational": "meta-llama/Llama-3.1-8B-Instruct",
119119
},
120120
"ovhcloud": {
121-
"automatic-speech-recognition": "openai/whisper-large-v3",
122121
"conversational": "meta-llama/Llama-3.1-8B-Instruct",
123-
"feature-extraction": "BAAI/bge-m3",
124-
"text-to-image": "stabilityai/stable-diffusion-xl-base-1.0",
122+
"text-generation": "meta-llama/Llama-3.1-8B-Instruct",
125123
},
126124
"replicate": {
127125
"text-to-image": "ByteDance/SDXL-Lightning",

tests/test_inference_providers.py

Lines changed: 46 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask
4747
from huggingface_hub.inference._providers.nscale import NscaleConversationalTask, NscaleTextToImageTask
4848
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
49-
from huggingface_hub.inference._providers.ovhcloud import OVHcloudAIEndpointsAutomaticSpeechRecognitionTask, OVHcloudAIEndpointsConversationalTask, OVHcloudAIEndpointsFeatureExtractionTask, OVHcloudAIEndpointsTextToImageTask
49+
from huggingface_hub.inference._providers.ovhcloud import (
50+
OVHcloudAIEndpointsConversationalTask,
51+
OVHcloudAIEndpointsTextGenerationTask,
52+
)
5053
from huggingface_hub.inference._providers.publicai import PublicAIConversationalTask
5154
from huggingface_hub.inference._providers.replicate import (
5255
ReplicateAutomaticSpeechRecognitionTask,
@@ -1470,94 +1473,62 @@ def test_prepare_payload_as_dict(self):
14701473
"top_p": 1,
14711474
}
14721475

1473-
def test_prepare_url_feature_extraction(self):
1474-
helper = OVHcloudAIEndpointsFeatureExtractionTask()
1475-
assert (
1476-
helper._prepare_url("hf_token", "username/repo_name")
1477-
== "https://router.huggingface.co/ovhcloud/v1/embeddings"
1478-
)
1476+
def test_prepare_route_conversational(self):
1477+
helper = OVHcloudAIEndpointsConversationalTask()
1478+
assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/chat/completions"
14791479

1480-
def test_prepare_payload_as_dict_feature_extraction(self):
1481-
helper = OVHcloudAIEndpointsFeatureExtractionTask()
1482-
payload = helper._prepare_payload_as_dict(
1483-
"Example text to embed",
1484-
{"truncate": True},
1485-
InferenceProviderMapping(
1486-
provider="ovhcloud",
1487-
hf_model_id="BAAI/bge-m3",
1488-
providerId="BGE-M3",
1489-
task="feature-extraction",
1490-
status="live",
1491-
),
1492-
)
1493-
assert payload == {"input": "Example text to embed", "model": "BGE-M3", "truncate": True}
1480+
def test_prepare_url_text_generation(self):
1481+
helper = OVHcloudAIEndpointsTextGenerationTask()
1482+
url = helper._prepare_url("hf_token", "username/repo_name")
1483+
assert url == "https://router.huggingface.co/ovhcloud/v1/chat/completions"
14941484

1495-
def test_prepare_url_text_to_image(self):
1496-
helper = OVHcloudAIEndpointsTextToImageTask()
1497-
assert (
1498-
helper._prepare_url("hf_token", "username/repo_name")
1499-
== "https://router.huggingface.co/ovhcloud/v1/images/generations"
1500-
)
1501-
15021485
url = helper._prepare_url("ovhcloud_token", "username/repo_name")
1503-
assert url == "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1/images/generations"
1504-
1505-
def test_prepare_payload_as_dict_text_to_image(self):
1506-
helper = OVHcloudAIEndpointsTextToImageTask()
1507-
payload = helper._prepare_payload_as_dict(
1508-
inputs="a beautiful cat",
1509-
provider_mapping_info=InferenceProviderMapping(
1510-
provider="ovhcloud",
1511-
hf_model_id="stabilityai/stable-diffusion-xl-base-1.0",
1512-
providerId="stable-diffusion-xl-base-v10",
1513-
task="text-to-image",
1514-
status="live",
1515-
),
1516-
parameters={}
1517-
)
1518-
assert payload == {
1519-
"prompt": "a beautiful cat",
1520-
"model": "stable-diffusion-xl-base-v10",
1521-
}
1522-
1523-
def test_text_to_image_get_response(self):
1524-
helper = OVHcloudAIEndpointsTextToImageTask()
1525-
response = helper.get_response({"data": [{"b64_json": base64.b64encode(b"image_bytes").decode()}]})
1526-
assert response == b"image_bytes"
1486+
assert url == "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1/chat/completions"
15271487

1528-
def test_prepare_url_automatic_speech_recognition(self):
1529-
helper = OVHcloudAIEndpointsAutomaticSpeechRecognitionTask()
1530-
assert (
1531-
helper._prepare_url("hf_token", "username/repo_name")
1532-
== "https://router.huggingface.co/ovhcloud/v1/audio/transcriptions"
1533-
)
1534-
1535-
url = helper._prepare_url("ovhcloud_token", "username/repo_name")
1536-
assert url == "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1/audio/transcriptions"
1488+
def test_prepare_route_text_generation(self):
1489+
helper = OVHcloudAIEndpointsTextGenerationTask()
1490+
assert helper._prepare_route("username/repo_name", "hf_token") == "/v1/chat/completions"
15371491

1538-
def test_prepare_payload_as_dict_automatic_speech_recognition(self):
1539-
helper = OVHcloudAIEndpointsAutomaticSpeechRecognitionTask()
1540-
1492+
def test_prepare_payload_as_dict_text_generation(self):
1493+
helper = OVHcloudAIEndpointsTextGenerationTask()
15411494
payload = helper._prepare_payload_as_dict(
1542-
f"data:audio/mpeg;base64,{base64.b64encode(b'dummy_audio_data').decode()}",
1543-
{},
1495+
"Once upon a time",
1496+
{"temperature": 0.7, "max_tokens": 100},
15441497
InferenceProviderMapping(
15451498
provider="ovhcloud",
1546-
hf_model_id="openai/whisper-large-v3",
1547-
providerId="whisper-large-v3",
1548-
task="automatic-speech-recognition",
1499+
hf_model_id="meta-llama/Llama-3.1-8B-Instruct",
1500+
providerId="Llama-3.1-8B-Instruct",
1501+
task="text-generation",
15491502
status="live",
15501503
),
15511504
)
15521505
assert payload == {
1553-
"file": f"data:audio/mpeg;base64,{base64.b64encode(b'dummy_audio_data').decode()}",
1554-
"model": "whisper-large-v3",
1506+
"prompt": "Once upon a time",
1507+
"temperature": 0.7,
1508+
"max_tokens": 100,
1509+
"model": "Llama-3.1-8B-Instruct",
15551510
}
15561511

1557-
def test_automatic_speech_recognition_get_response(self):
1558-
helper = OVHcloudAIEndpointsAutomaticSpeechRecognitionTask()
1559-
response = helper.get_response({"text": "Hello world"})
1560-
assert response == "Hello world"
1512+
def test_text_generation_get_response(self):
1513+
helper = OVHcloudAIEndpointsTextGenerationTask()
1514+
response = helper.get_response(
1515+
{
1516+
"choices": [
1517+
{
1518+
"text": " there was a beautiful princess",
1519+
"finish_reason": "stop",
1520+
"seed": 42,
1521+
}
1522+
]
1523+
}
1524+
)
1525+
assert response == {
1526+
"generated_text": " there was a beautiful princess",
1527+
"details": {
1528+
"finish_reason": "stop",
1529+
"seed": 42,
1530+
},
1531+
}
15611532

15621533

15631534
class TestReplicateProvider:

0 commit comments

Comments (0)