Skip to content

Commit e85fb98

Browse files
committed
api inference compat response
Signed-off-by: Raphael Glon <[email protected]>
1 parent 2cedfbd commit e85fb98

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

src/huggingface_inference_toolkit/env_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
3+
14
def strtobool(val: str) -> bool:
25
"""Convert a string representation of truth to True or False booleans.
36
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
@@ -20,3 +23,7 @@ def strtobool(val: str) -> bool:
2023
raise ValueError(
2124
f"Invalid truth value, it should be a string but {val} was provided instead."
2225
)
26+
27+
28+
def api_inference_compat():
29+
return strtobool(os.getenv("API_INFERENCE_COMPAT", "false"))

src/huggingface_inference_toolkit/sentence_transformers_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import importlib.util
2+
import os
23
from typing import Any, Dict, List, Tuple, Union
34

5+
from huggingface_inference_toolkit.env_utils import api_inference_compat
6+
47
try:
58
from typing import Literal
69
except ImportError:
@@ -26,7 +29,10 @@ def __call__(self, source_sentence: str, sentences: List[str]) -> Dict[str, floa
2629
embeddings1 = self.model.encode(source_sentence, convert_to_tensor=True)
2730
embeddings2 = self.model.encode(sentences, convert_to_tensor=True)
2831
similarities = util.pytorch_cos_sim(embeddings1, embeddings2).tolist()[0]
29-
return {"similarities": similarities}
32+
if api_inference_compat():
33+
return similarities
34+
else:
35+
return {"similarities": similarities}
3036

3137

3238
class SentenceEmbeddingPipeline:
@@ -36,7 +42,10 @@ def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: An
3642

3743
def __call__(self, sentences: Union[str, List[str]]) -> Dict[str, List[float]]:
3844
embeddings = self.model.encode(sentences).tolist()
39-
return {"embeddings": embeddings}
45+
if api_inference_compat():
46+
return embeddings
47+
else:
48+
return {"embeddings": embeddings}
4049

4150

4251
class SentenceRankingPipeline:

src/huggingface_inference_toolkit/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib.util
2+
import os
23
import sys
34
from pathlib import Path
45
from typing import Optional, Union

0 commit comments

Comments
 (0)