
Commit 7bee08c

mogith-pn and Wauplin authored
Add clarifai as Inference provider (#3424)
* added clarifai as provider
* updated ClarifaiConversationalTask class
* Fix route in Clarifai provider

Co-authored-by: Lucain Pouget <[email protected]>
1 parent 5b1a914 commit 7bee08c

File tree

7 files changed (+85 -37 lines changed)


docs/source/en/guides/inference.md

Lines changed: 35 additions & 35 deletions
Large diffs are not rendered by default.

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ class InferenceClient:
             Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
             arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
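
For reference, a minimal usage sketch of the new provider through the documented `provider` argument; the model ID and token below are placeholders, not values from this commit:

from huggingface_hub import InferenceClient

# Route a conversational request through the Clarifai provider
# (placeholder token and model ID, for illustration only).
client = InferenceClient(provider="clarifai", api_key="hf_xxx")
completion = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    model="meta-llama/llama-3.1-8B-Instruct",
)
print(completion.choices[0].message.content)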

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ class AsyncInferenceClient:
             Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
             arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
             Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str`, *optional*):
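
The async client takes the same argument; a corresponding sketch with `AsyncInferenceClient` (same placeholder values as above):

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Same request as the sync example, issued through the async client.
    client = AsyncInferenceClient(provider="clarifai", api_key="hf_xxx")
    completion = await client.chat_completion(
        messages=[{"role": "user", "content": "Hello!"}],
        model="meta-llama/llama-3.1-8B-Instruct",
    )
    print(completion.choices[0].message.content)


asyncio.run(main())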

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -9,6 +9,7 @@
 from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
 from .black_forest_labs import BlackForestLabsTextToImageTask
 from .cerebras import CerebrasConversationalTask
+from .clarifai import ClarifaiConversationalTask
 from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,

@@ -50,6 +51,7 @@
 PROVIDER_T = Literal[
     "black-forest-labs",
     "cerebras",
+    "clarifai",
     "cohere",
     "fal-ai",
     "featherless-ai",

@@ -78,6 +80,9 @@
     "cerebras": {
         "conversational": CerebrasConversationalTask(),
     },
+    "clarifai": {
+        "conversational": ClarifaiConversationalTask(),
+    },
     "cohere": {
         "conversational": CohereConversationalTask(),
     },
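
A short, hedged sketch of how the registry above is consulted at runtime; the `get_provider_helper` lookup and its exact signature live outside this diff and are assumed here:

# Assumption: `get_provider_helper` (defined elsewhere in this module) resolves a
# provider/task pair to the registered helper instance; its signature may differ.
from huggingface_hub.inference._providers import get_provider_helper

helper = get_provider_helper("clarifai", task="conversational", model="meta-llama/llama-3.1-8B-Instruct")
print(type(helper).__name__)  # expected: ClarifaiConversationalTask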

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
     # status="live")
     "cerebras": {},
     "cohere": {},
+    "clarifai": {},
     "fal-ai": {},
     "fireworks-ai": {},
     "groq": {},
src/huggingface_hub/inference/_providers/clarifai.py (new file)

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+from ._common import BaseConversationalTask
+
+
+_PROVIDER = "clarifai"
+_BASE_URL = "https://api.clarifai.com"
+
+
+class ClarifaiConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
+
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
+        return "/v2/ext/openai/v1/chat/completions"

tests/test_inference_providers.py

Lines changed: 29 additions & 0 deletions

@@ -17,6 +17,7 @@
     recursive_merge,
 )
 from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask
+from huggingface_hub.inference._providers.clarifai import ClarifaiConversationalTask
 from huggingface_hub.inference._providers.cohere import CohereConversationalTask
 from huggingface_hub.inference._providers.fal_ai import (
     _POLLING_INTERVAL,

@@ -293,6 +294,34 @@ def test_prepare_payload_as_dict(self):
         }


+class TestClarifaiProvider:
+    def test_prepare_url(self):
+        helper = ClarifaiConversationalTask()
+        assert (
+            helper._prepare_url("clarifai_api_key", "username/repo_name")
+            == "https://api.clarifai.com/v2/ext/openai/v1/chat/completions"
+        )
+
+    def test_prepare_payload_as_dict(self):
+        helper = ClarifaiConversationalTask()
+        payload = helper._prepare_payload_as_dict(
+            [{"role": "user", "content": "Hello!"}],
+            {},
+            InferenceProviderMapping(
+                provider="clarifai",
+                hf_model_id="meta-llama/llama-3.1-8B-Instruct",
+                providerId="meta-llama/llama-3.1-8B-Instruct",
+                task="conversational",
+                status="live",
+            ),
+        )
+
+        assert payload == {
+            "messages": [{"role": "user", "content": "Hello!"}],
+            "model": "meta-llama/llama-3.1-8B-Instruct",
+        }
+
+
 class TestFalAIProvider:
     def test_prepare_headers_fal_ai_key(self):
         """When using direct call, must use Key authorization."""
