Skip to content

Commit 963f728

Browse files
committed
zero shot classif: api inference compat
Signed-off-by: Raphael Glon <[email protected]>
1 parent a781375 commit 963f728

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

src/huggingface_inference_toolkit/handler.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
trust_remote_code=HF_TRUST_REMOTE_CODE,
2626
)
2727

28-
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
28+
def __call__(self, data: Dict[str, Any]):
2929
"""
3030
Handles an inference request with input data and makes a prediction.
3131
Args:
@@ -132,7 +132,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
132132
if isinstance(inputs, list):
133133
if isinstance(resp, list) and len(resp) == len(inputs):
134134
for it in resp:
135-
# Batch size dim is the first it level, dicard it
135+
# Batch size dim is the first it level, discard it
136136
if isinstance(it, list) and len(it) == 1:
137137
new_resp.append(it[0])
138138
else:
@@ -160,6 +160,25 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
160160
el["score"] = 1
161161
new_resp.append(el)
162162
resp = new_resp
163+
if self.pipeline.task == "zero-shot-classification":
164+
try:
165+
if isinstance(resp, dict):
166+
if 'labels' in resp and 'scores' in resp:
167+
labels = resp['labels']
168+
scores = resp['scores']
169+
if len(labels) == len(scores):
170+
new_resp = []
171+
for label, score in zip(labels, scores):
172+
new_resp.append({"label": label, "score": score})
173+
resp = new_resp
174+
else:
175+
raise Exception("labels and scores do not have the same len, {} != {}".format(
176+
len(labels), len(scores)))
177+
else:
178+
raise Exception("Missing labels or scores key in response dict {}".format(resp))
179+
except Exception as e:
180+
logging.logger.warning("Unable to remap response for api inference compat")
181+
logging.logger.exception(e)
163182
return resp
164183

165184

0 commit comments

Comments
 (0)