Skip to content

Commit 590bffd

Browse files
committed
fix: token classification api-inference-compat
Signed-off-by: Raphael Glon <[email protected]>
1 parent 0103db4 commit 590bffd

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

src/huggingface_inference_toolkit/handler.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,18 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
102102
"or `candidateLabels`."
103103
)
104104

105-
if api_inference_compat() and self.pipeline.task in ["text-classification", "token-classification"] and \
106-
isinstance(inputs, str):
107-
inputs = [inputs]
108-
if self.pipeline.task == "text-classification" and "top_k" not in parameters:
109-
top_k = os.environ.get("DEFAULT_TOP_K", 5)
110-
parameters["top_k"] = top_k
111-
resp = self.pipeline(inputs, **parameters)
112-
# # We don't want to return {}
113-
if isinstance(resp, list) and len(resp) > 0:
114-
if not isinstance(resp[0], list):
115-
return [resp]
116-
else:
117-
return resp
118-
else:
105+
if api_inference_compat():
106+
if self.pipeline.task == "text-classification" and isinstance(inputs, str):
107+
inputs = [inputs]
108+
parameters.setdefault("top_k", os.environ.get("DEFAULT_TOP_K", 5))
109+
resp = self.pipeline(inputs, **parameters)
110+
# # We don't want to return {}
111+
if isinstance(resp, list) and len(resp) > 0:
112+
if not isinstance(resp[0], list):
113+
return [resp]
119114
return resp
120-
115+
if self.pipeline.task == "token-classification":
116+
parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))
121117
return (
122118
self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters) # type: ignore
123119
)

0 commit comments

Comments
 (0)