Skip to content

Commit c734d0c

Browse files
[Inference Providers] fold OpenAI support into provider parameter (#2949)
* add openai as a provider * fix
1 parent df6366c commit c734d0c

File tree

6 files changed

+42
-4
lines changed

6 files changed

+42
-4
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class InferenceClient:
133133
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
134134
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
135135
provider (`str`, *optional*):
136-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
136+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
137137
defaults to hf-inference (Hugging Face Serverless Inference API).
138138
If model is a URL or `base_url` is passed, then `provider` is not used.
139139
token (`str`, *optional*):

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class AsyncInferenceClient:
121121
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
122122
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
123123
provider (`str`, *optional*):
124-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
124+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, `"sambanova"` or `"together"`.
125125
defaults to hf-inference (Hugging Face Serverless Inference API).
126126
If model is a URL or `base_url` is passed, then `provider` is not used.
127127
token (`str`, *optional*):

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
1616
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
1717
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
18+
from .openai import OpenAIConversationalTask
1819
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
1920
from .sambanova import SambanovaConversationalTask
2021
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
@@ -30,6 +31,7 @@
3031
"hyperbolic",
3132
"nebius",
3233
"novita",
34+
"openai",
3335
"replicate",
3436
"sambanova",
3537
"together",
@@ -97,6 +99,9 @@
9799
"conversational": NovitaConversationalTask(),
98100
"text-to-video": NovitaTextToVideoTask(),
99101
},
102+
"openai": {
103+
"conversational": OpenAIConversationalTask(),
104+
},
100105
"replicate": {
101106
"text-to-image": ReplicateTask("text-to-image"),
102107
"text-to-speech": ReplicateTextToSpeechTask(),

src/huggingface_hub/inference/_providers/new_provider.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper):
2323
"""Define high-level parameters."""
2424
super().__init__(provider=..., base_url=..., task=...)
2525

26-
def get_response(self, response: Union[bytes, Dict]) -> Any:
26+
def get_response(
27+
self,
28+
response: Union[bytes, Dict],
29+
request_params: Optional[RequestParameters] = None,
30+
) -> Any:
2731
"""
2832
Return the response in the expected format.
2933
@@ -37,7 +41,7 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper):
3741
"""
3842
return super()._prepare_headers(headers, api_key)
3943

40-
def _prepare_route(self, mapped_model: str) -> str:
44+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
4145
"""Return the route to use for the request.
4246
4347
Override this method in subclasses for customized routes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
from typing import Optional

from huggingface_hub.inference._providers._common import BaseConversationalTask


class OpenAIConversationalTask(BaseConversationalTask):
    """Conversational (chat-completions) helper that targets the OpenAI API directly.

    Unlike the other providers, requests are never routed through Hugging Face:
    the caller must supply their own OpenAI credentials and model ID.
    """

    def __init__(self):
        # Requests go straight to api.openai.com rather than through HF routing.
        super().__init__(provider="openai", base_url="https://api.openai.com")

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        """Validate and return the user-supplied OpenAI API key.

        Raises:
            ValueError: if no key was given, or if a Hugging Face token
                (``hf_``-prefixed) was passed instead of an OpenAI key.
        """
        if api_key is None:
            raise ValueError("You must provide an api_key to work with OpenAI API.")
        if not api_key.startswith("hf_"):
            return api_key
        # HF tokens cannot be exchanged for OpenAI access — fail loudly.
        raise ValueError(
            "OpenAI provider is not available through Hugging Face routing, please use your own OpenAI API key."
        )

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        """Return the model ID unchanged; there is no HF→provider mapping for OpenAI."""
        if model is None:
            raise ValueError("Please provide an OpenAI model ID, e.g. `gpt-4o` or `o1`.")
        return model

tests/test_inference_providers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
NovitaConversationalTask,
3838
NovitaTextGenerationTask,
3939
)
40+
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
4041
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
4142
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask
4243
from huggingface_hub.inference._providers.together import (
@@ -707,6 +708,12 @@ def test_prepare_url_conversational(self):
707708
assert url == "https://api.novita.ai/v3/openai/chat/completions"
708709

709710

711+
class TestOpenAIProvider:
    """Tests for the direct-to-OpenAI conversational provider helper."""

    def test_prepare_url(self):
        # The chat-completions route must be appended to the OpenAI base URL.
        url = OpenAIConversationalTask()._prepare_url("sk-XXXXXX", "gpt-4o-mini")
        assert url == "https://api.openai.com/v1/chat/completions"
715+
716+
710717
class TestReplicateProvider:
711718
def test_prepare_headers(self):
712719
helper = ReplicateTask("text-to-image")

0 commit comments

Comments
 (0)