Skip to content

Commit 0103db4

Browse files
committed
fix(api inference): compat for text-classification token-classification
Signed-off-by: Raphael Glon <[email protected]>
1 parent 2a6e662 commit 0103db4

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/huggingface_inference_toolkit/handler.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Dict, Literal, Optional, Union
44

55
from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
6+
from huggingface_inference_toolkit.env_utils import api_inference_compat
67
from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
78
from huggingface_inference_toolkit.utils import (
89
check_and_register_custom_pipeline_from_directory,
@@ -101,6 +102,22 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
101102
"or `candidateLabels`."
102103
)
103104

105+
if api_inference_compat() and self.pipeline.task in ["text-classification", "token-classification"] and \
106+
isinstance(inputs, str):
107+
inputs = [inputs]
108+
if self.pipeline.task == "text-classification" and "top_k" not in parameters:
109+
top_k = os.environ.get("DEFAULT_TOP_K", 5)
110+
parameters["top_k"] = top_k
111+
resp = self.pipeline(inputs, **parameters)
112+
# # We don't want to return {}
113+
if isinstance(resp, list) and len(resp) > 0:
114+
if not isinstance(resp[0], list):
115+
return [resp]
116+
else:
117+
return resp
118+
else:
119+
return resp
120+
104121
return (
105122
self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters) # type: ignore
106123
)

0 commit comments

Comments
 (0)