diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md index bb92f18d5c..36a93f049e 100644 --- a/docs/source/en/guides/inference.md +++ b/docs/source/en/guides/inference.md @@ -207,36 +207,36 @@ For more details, refer to the [Inference Providers pricing documentation](https [`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks: -| Task | Black Forest Labs | Cerebras | Cohere | fal-ai | Featherless AI | Fireworks AI | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together | -| --------------------------------------------------- | ----------------- | -------- | ------ | ------ | -------------- | ------------ | ------------ | ---------- | ---------------- | --------- | --------- | --------- | -------- | -| [`~InferenceClient.audio_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.audio_to_audio`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.automatic_speech_recognition`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.chat_completion`] | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | -| [`~InferenceClient.document_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.feature_extraction`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | -| [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.sentence_similarity`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.summarization`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.table_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.text_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.text_generation`] | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | -| [`~InferenceClient.text_to_image`] | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | -| [`~InferenceClient.text_to_speech`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | -| [`~InferenceClient.text_to_video`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | -| [`~InferenceClient.tabular_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.tabular_regression`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.token_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.translation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.visual_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.zero_shot_image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.zero_shot_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| Task | Black Forest Labs | Cerebras | Cohere | fal-ai | Featherless AI | Fireworks AI | Groq | HF Inference | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together | +| --------------------------------------------------- | ----------------- | -------- | ------ | ------ | -------------- | ------------ | ---- | ------------ | ---------- | ---------------- | --------- | --------- | --------- | -------- | +| [`~InferenceClient.audio_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.audio_to_audio`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.automatic_speech_recognition`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.chat_completion`] | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | +| [`~InferenceClient.document_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.feature_extraction`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | +| [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | +| [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.sentence_similarity`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.summarization`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.table_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.text_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.text_generation`] | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | +| [`~InferenceClient.text_to_image`] | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | +| [`~InferenceClient.text_to_speech`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| [`~InferenceClient.text_to_video`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| [`~InferenceClient.tabular_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.tabular_regression`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.token_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.translation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.visual_question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.zero_shot_image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.zero_shot_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index df0d50a303..3439dafd89 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -134,7 +134,7 @@ class InferenceClient: path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index d7ad020ebc..162d89369f 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -122,7 +122,7 @@ class AsyncInferenceClient: path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index 9d1505c372..8d73b837fc 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -17,6 +17,7 @@ FalAITextToVideoTask, ) from .fireworks_ai import FireworksAIConversationalTask +from .groq import GroqConversationalTask from .hf_inference import ( HFInferenceBinaryInputTask, HFInferenceConversational, @@ -48,6 +49,7 @@ "fal-ai", "featherless-ai", "fireworks-ai", + "groq", "hf-inference", "hyperbolic", "nebius", @@ -84,6 +86,9 @@ "fireworks-ai": { "conversational": FireworksAIConversationalTask(), }, + "groq": { + "conversational": GroqConversationalTask(), + }, "hf-inference": { "text-to-image": HFInferenceTask("text-to-image"), "conversational": HFInferenceConversational(), diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 96efcb4943..97b774a4f2 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -25,6 +25,7 @@ "cohere": {}, "fal-ai": {}, "fireworks-ai": {}, + "groq": {}, "hf-inference": {}, "hyperbolic": {}, "nebius": {}, diff --git a/src/huggingface_hub/inference/_providers/groq.py b/src/huggingface_hub/inference/_providers/groq.py new file mode 100644 index 0000000000..11e677504e --- /dev/null +++ b/src/huggingface_hub/inference/_providers/groq.py @@ -0,0 +1,9 @@ +from ._common import BaseConversationalTask + + +class GroqConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="groq", base_url="https://api.groq.com") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/openai/v1/chat/completions" diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index e4d1e468c1..9df3c3026c 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -29,6 +29,7 @@ FeatherlessTextGenerationTask, ) from huggingface_hub.inference._providers.fireworks_ai import FireworksAIConversationalTask +from huggingface_hub.inference._providers.groq import GroqConversationalTask from huggingface_hub.inference._providers.hf_inference import ( HFInferenceBinaryInputTask, HFInferenceConversational, @@ -436,6 +437,13 @@ def test_prepare_payload_as_dict(self): } +class TestGroqProvider: + def test_prepare_route(self): + """Test route preparation for Groq conversational task.""" + helper = GroqConversationalTask() + assert helper._prepare_route("username/repo_name", "hf_token") == "/openai/v1/chat/completions" + + class TestHFInferenceProvider: def test_prepare_mapping_info(self, mocker): helper = HFInferenceTask("text-classification")