Skip to content

Commit 8b54e42

Browse files
[Inference Providers] sambanova supports feature extraction (#3037)
* add feature extraction for sambanova
* nit
* update table
* hf inference feature extraction
* fix linter
* fix
1 parent bebc1f7 commit 8b54e42

File tree

7 files changed

+70
-11
lines changed

7 files changed

+70
-11
lines changed

docs/source/en/guides/inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
191191
| [`~InferenceClient.automatic_speech_recognition`] |||||||||||||
192192
| [`~InferenceClient.chat_completion`] |||||||||||||
193193
| [`~InferenceClient.document_question_answering`] |||||||||||||
194-
| [`~InferenceClient.feature_extraction`] ||||||||||| ||
194+
| [`~InferenceClient.feature_extraction`] ||||||||||| ||
195195
| [`~InferenceClient.fill_mask`] |||||||||||||
196196
| [`~InferenceClient.image_classification`] |||||||||||||
197197
| [`~InferenceClient.image_segmentation`] |||||||||||||

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ def feature_extraction(
10801080
)
10811081
response = self._inner_post(request_parameters)
10821082
np = _import_numpy()
1083-
return np.array(_bytes_to_dict(response), dtype="float32")
1083+
return np.array(provider_helper.get_response(response), dtype="float32")
10841084

10851085
def fill_mask(
10861086
self,

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,7 @@ async def feature_extraction(
11221122
)
11231123
response = await self._inner_post(request_parameters)
11241124
np = _import_numpy()
1125-
return np.array(_bytes_to_dict(response), dtype="float32")
1125+
return np.array(provider_helper.get_response(response), dtype="float32")
11261126

11271127
async def fill_mask(
11281128
self,

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
FalAITextToVideoTask,
1414
)
1515
from .fireworks_ai import FireworksAIConversationalTask
16-
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
16+
from .hf_inference import (
17+
HFInferenceBinaryInputTask,
18+
HFInferenceConversational,
19+
HFInferenceFeatureExtractionTask,
20+
HFInferenceTask,
21+
)
1722
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
1823
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
1924
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
2025
from .openai import OpenAIConversationalTask
2126
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
22-
from .sambanova import SambanovaConversationalTask
27+
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
2328
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
2429

2530

@@ -72,7 +77,7 @@
7277
"audio-classification": HFInferenceBinaryInputTask("audio-classification"),
7378
"automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
7479
"fill-mask": HFInferenceTask("fill-mask"),
75-
"feature-extraction": HFInferenceTask("feature-extraction"),
80+
"feature-extraction": HFInferenceFeatureExtractionTask(),
7681
"image-classification": HFInferenceBinaryInputTask("image-classification"),
7782
"image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
7883
"document-question-answering": HFInferenceTask("document-question-answering"),
@@ -116,6 +121,7 @@
116121
},
117122
"sambanova": {
118123
"conversational": SambanovaConversationalTask(),
124+
"feature-extraction": SambanovaFeatureExtractionTask(),
119125
},
120126
"together": {
121127
"text-to-image": TogetherTextToImageTask(),

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import json
22
from functools import lru_cache
33
from pathlib import Path
4-
from typing import Any, Dict, Optional
4+
from typing import Any, Dict, Optional, Union
55

66
from huggingface_hub import constants
77
from huggingface_hub.hf_api import InferenceProviderMapping
8-
from huggingface_hub.inference._common import _b64_encode, _open_as_binary
8+
from huggingface_hub.inference._common import RequestParameters, _b64_encode, _bytes_to_dict, _open_as_binary
99
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
1010
from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status
1111

@@ -177,3 +177,13 @@ def _check_supported_task(model: str, task: str) -> None:
177177
f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'"
178178
)
179179
return
180+
181+
182+
class HFInferenceFeatureExtractionTask(HFInferenceTask):
    """HF Inference task helper bound to the "feature-extraction" task."""

    def __init__(self):
        super().__init__("feature-extraction")

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Return the response payload, decoding it first when it arrives as raw bytes.

        Args:
            response: Either raw bytes or an already-parsed object.
            request_params: Accepted for signature compatibility; not read here.

        Returns:
            The result of `_bytes_to_dict(response)` for bytes input, otherwise
            `response` itself, unchanged.
        """
        return _bytes_to_dict(response) if isinstance(response, bytes) else response

src/huggingface_hub/inference/_providers/sambanova.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
1-
from huggingface_hub.inference._providers._common import BaseConversationalTask
1+
from typing import Any, Dict, Optional, Union
2+
3+
from huggingface_hub.hf_api import InferenceProviderMapping
4+
from huggingface_hub.inference._common import RequestParameters, _as_dict
5+
from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none
26

37

48
class SambanovaConversationalTask(BaseConversationalTask):
    """Conversational task helper configured for the SambaNova provider."""

    def __init__(self):
        # Same configuration as before, passed as keywords in a different order.
        super().__init__(base_url="https://api.sambanova.ai", provider="sambanova")
11+
12+
13+
class SambanovaFeatureExtractionTask(TaskProviderHelper):
    """Provider helper for the "feature-extraction" task on SambaNova."""

    def __init__(self):
        super().__init__(task="feature-extraction", provider="sambanova", base_url="https://api.sambanova.ai")

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the embeddings route; neither the model nor the key affects it."""
        return "/v1/embeddings"

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the request body for the embeddings endpoint.

        Args:
            inputs: Value sent under the "input" key.
            parameters: Extra request options; passed through `filter_none` first.
            provider_mapping_info: Supplies `provider_id`, sent under the "model" key.

        Returns:
            A dict with "input" and "model", followed by the filtered parameters
            (which take precedence on key collisions, as with dict unpacking).
        """
        extra_options = filter_none(parameters)
        payload = {"input": inputs, "model": provider_mapping_info.provider_id}
        payload.update(extra_options)
        return payload

    def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
        """Extract the "embedding" value of every item under the response's "data" key.

        `request_params` is accepted for signature compatibility and is not read here.
        """
        parsed = _as_dict(response)
        vectors = []
        for item in parsed["data"]:
            vectors.append(item["embedding"])
        return vectors

tests/test_inference_providers.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask
3636
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
3737
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
38-
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask
38+
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
3939
from huggingface_hub.inference._providers.together import TogetherTextToImageTask
4040

4141
from .testing_utils import assert_in_logs
@@ -903,13 +903,34 @@ def test_get_response_single_output(self, mocker):
903903

904904

905905
class TestSambanovaProvider:
    """Unit tests for the SambaNova provider helpers."""

    def test_prepare_url_conversational(self):
        # With this key the expected URL points directly at api.sambanova.ai.
        url = SambanovaConversationalTask()._prepare_url("sambanova_token", "username/repo_name")
        assert url == "https://api.sambanova.ai/v1/chat/completions"

    def test_prepare_payload_as_dict_feature_extraction(self):
        task_helper = SambanovaFeatureExtractionTask()
        mapping = InferenceProviderMapping(
            hf_model_id="username/repo_name",
            providerId="provider-id",
            task="feature-extraction",
            status="live",
        )
        expected = {"input": "Hello world", "model": "provider-id", "truncate": True}
        assert task_helper._prepare_payload_as_dict("Hello world", {"truncate": True}, mapping) == expected

    def test_prepare_url_feature_extraction(self):
        # With this key the expected URL goes through router.huggingface.co.
        url = SambanovaFeatureExtractionTask()._prepare_url("hf_token", "username/repo_name")
        assert url == "https://router.huggingface.co/sambanova/v1/embeddings"
913934

914935
class TestTogetherProvider:
915936
def test_prepare_route_text_to_image(self):

0 commit comments

Comments
 (0)