Skip to content

Commit 249da97

Browse files
committed
Only add text-generation and conversational tasks, per feedback
1 parent ff21568 commit 249da97

File tree

4 files changed

+72
-153
lines changed

4 files changed

+72
-153
lines changed

docs/source/en/guides/inference.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,14 @@ For more details, refer to the [Inference Providers pricing documentation](https
192192

193193
[`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:
194194

195-
| Task | Black Forest Labs | Cerebras | Clarifai | Cohere | fal-ai | Featherless AI | Fireworks AI | Groq | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Nscale | OVHcloud | Public AI | Replicate | Sambanova | Scaleway | Together | Wavespeed | Zai |
195+
| Task | Black Forest Labs | Cerebras | Clarifai | Cohere | fal-ai | Featherless AI | Fireworks AI | Groq | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Nscale | OVHcloud AI Endpoints | Public AI | Replicate | Sambanova | Scaleway | Together | Wavespeed | Zai |
196196
| --------------------------------------------------- | ----------------- | -------- | -------- | ------ | ------ | -------------- | ------------ | ---- | ------------ | ---------- | ---------------- | --------- | ------ | -------- | ---------- | --------- | --------- | --------- | -------- | --------- | ---- |
197197
| [`~InferenceClient.audio_classification`] ||||||||||||||||||||||
198198
| [`~InferenceClient.audio_to_audio`] ||||||||||||||||||||||
199-
| [`~InferenceClient.automatic_speech_recognition`] |||||||||||||| ||||||||
199+
| [`~InferenceClient.automatic_speech_recognition`] |||||||||||||| ||||||||
200200
| [`~InferenceClient.chat_completion`] ||||||||||||||||||||||
201201
| [`~InferenceClient.document_question_answering`] ||||||||||||||||||||||
202-
| [`~InferenceClient.feature_extraction`] |||||||||||||| ||||||||
202+
| [`~InferenceClient.feature_extraction`] |||||||||||||| ||||||||
203203
| [`~InferenceClient.fill_mask`] ||||||||||||||||||||||
204204
| [`~InferenceClient.image_classification`] ||||||||||||||||||||||
205205
| [`~InferenceClient.image_segmentation`] ||||||||||||||||||||||
@@ -212,8 +212,8 @@ For more details, refer to the [Inference Providers pricing documentation](https
212212
| [`~InferenceClient.summarization`] ||||||||||||||||||||||
213213
| [`~InferenceClient.table_question_answering`] ||||||||||||||||||||||
214214
| [`~InferenceClient.text_classification`] ||||||||||||||||||||||
215-
| [`~InferenceClient.text_generation`] |||||||||||||| ||||||||
216-
| [`~InferenceClient.text_to_image`] |||||||||||||| ||||||||
215+
| [`~InferenceClient.text_generation`] |||||||||||||| ||||||||
216+
| [`~InferenceClient.text_to_image`] |||||||||||||| ||||||||
217217
| [`~InferenceClient.text_to_speech`] ||||||||||||||||||||||
218218
| [`~InferenceClient.text_to_video`] ||||||||||||||||||||||
219219
| [`~InferenceClient.tabular_classification`] ||||||||||||||||||||||

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
3939
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
4040
from .openai import OpenAIConversationalTask
41-
from .ovhcloud import OVHcloudAIEndpointsAutomaticSpeechRecognitionTask, OVHcloudAIEndpointsConversationalTask, OVHcloudAIEndpointsFeatureExtractionTask, OVHcloudAIEndpointsTextToImageTask
41+
from .ovhcloud import OVHcloudAIEndpointsConversationalTask, OVHcloudAIEndpointsTextGenerationTask
4242
from .publicai import PublicAIConversationalTask
4343
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
4444
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
@@ -170,9 +170,7 @@
170170
},
171171
"ovhcloud": {
172172
"conversational": OVHcloudAIEndpointsConversationalTask(),
173-
"text-to-image": OVHcloudAIEndpointsTextToImageTask(),
174-
"feature-extraction": OVHcloudAIEndpointsFeatureExtractionTask(),
175-
"automatic-speech-recognition": OVHcloudAIEndpointsAutomaticSpeechRecognitionTask(),
173+
"text-generation": OVHcloudAIEndpointsTextGenerationTask(),
176174
},
177175
"publicai": {
178176
"conversational": PublicAIConversationalTask(),
Lines changed: 19 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,34 @@
1-
import base64
2-
from abc import ABC
3-
from typing import Any, Dict, Optional, Union
1+
from typing import Any, Optional, Union
42

5-
from huggingface_hub.hf_api import InferenceProviderMapping
63
from huggingface_hub.inference._common import RequestParameters, _as_dict
7-
from huggingface_hub.inference._providers._common import (
8-
TaskProviderHelper,
9-
filter_none,
10-
)
4+
from huggingface_hub.inference._providers._common import BaseConversationalTask, BaseTextGenerationTask
5+
116

127
_PROVIDER = "ovhcloud"
138
_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net"
149

15-
class OVHcloudAIEndpointsTask(TaskProviderHelper, ABC):
16-
def __init__(self, task: str):
17-
super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)
18-
19-
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
20-
if self.task == "text-to-image":
21-
return "/v1/images/generations"
22-
elif self.task == "conversational":
23-
return "/v1/chat/completions"
24-
elif self.task == "feature-extraction":
25-
return "/v1/embeddings"
26-
elif self.task == "automatic-speech-recognition":
27-
return "/v1/audio/transcriptions"
28-
raise ValueError(f"Unsupported task '{self.task}' for OVHcloud AI Endpoints.")
29-
30-
def _prepare_payload_as_dict(
31-
self, messages: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
32-
) -> Optional[Dict]:
33-
return {"messages": messages, "model": provider_mapping_info.provider_id, **filter_none(parameters)}
34-
35-
36-
class OVHcloudAIEndpointsConversationalTask(OVHcloudAIEndpointsTask):
37-
def __init__(self):
38-
super().__init__("conversational")
39-
40-
def _prepare_payload_as_dict(
41-
self, messages: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
42-
) -> Optional[dict]:
43-
return super()._prepare_payload_as_dict(messages, parameters, provider_mapping_info)
44-
4510

46-
class OVHcloudAIEndpointsTextToImageTask(OVHcloudAIEndpointsTask):
11+
class OVHcloudAIEndpointsConversationalTask(BaseConversationalTask):
4712
def __init__(self):
48-
super().__init__("text-to-image")
13+
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
4914

50-
def _prepare_payload_as_dict(
51-
self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
52-
) -> Optional[dict]:
53-
mapped_model = provider_mapping_info.provider_id
54-
return {"prompt": inputs, "model": mapped_model, **filter_none(parameters)}
15+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
16+
return "/v1/chat/completions"
5517

56-
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
57-
response_dict = _as_dict(response)
58-
return base64.b64decode(response_dict["data"][0]["b64_json"])
59-
60-
class OVHcloudAIEndpointsFeatureExtractionTask(OVHcloudAIEndpointsTask):
61-
def __init__(self):
62-
super().__init__("feature-extraction")
6318

64-
def _prepare_payload_as_dict(
65-
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
66-
) -> Optional[Dict]:
67-
return {"input": inputs, "model": provider_mapping_info.provider_id, **filter_none(parameters)}
68-
69-
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
70-
embeddings = _as_dict(response)["data"]
71-
return [embedding["embedding"] for embedding in embeddings]
72-
73-
class OVHcloudAIEndpointsAutomaticSpeechRecognitionTask(OVHcloudAIEndpointsTask):
19+
class OVHcloudAIEndpointsTextGenerationTask(BaseTextGenerationTask):
7420
def __init__(self):
75-
super().__init__("automatic-speech-recognition")
21+
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
7622

77-
def _prepare_payload_as_dict(
78-
self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
79-
) -> Optional[dict]:
80-
return {"file": inputs, "model": provider_mapping_info.provider_id, **filter_none(parameters)}
23+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
24+
return "/v1/chat/completions"
8125

8226
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
83-
response_dict = _as_dict(response)
84-
return response_dict["text"]
27+
output = _as_dict(response)["choices"][0]
28+
return {
29+
"generated_text": output["text"],
30+
"details": {
31+
"finish_reason": output.get("finish_reason"),
32+
"seed": output.get("seed"),
33+
},
34+
}

0 commit comments

Comments (0)