@@ -25,7 +25,7 @@ def __init__(
2525 trust_remote_code = HF_TRUST_REMOTE_CODE ,
2626 )
2727
28- def __call__ (self , data : Dict [str , Any ]) -> Dict [ str , Any ] :
28+ def __call__ (self , data : Dict [str , Any ]):
2929 """
3030 Handles an inference request with input data and makes a prediction.
3131 Args:
@@ -132,7 +132,7 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
132132 if isinstance (inputs , list ):
133133 if isinstance (resp , list ) and len (resp ) == len (inputs ):
134134 for it in resp :
135- # Batch size dim is the first it level, dicard it
135+ # Batch size dim is the first it level, discard it
136136 if isinstance (it , list ) and len (it ) == 1 :
137137 new_resp .append (it [0 ])
138138 else :
@@ -160,6 +160,25 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
160160 el ["score" ] = 1
161161 new_resp .append (el )
162162 resp = new_resp
163+ if self .pipeline .task == "zero-shot-classification" :
164+ try :
165+ if isinstance (resp , dict ):
166+ if 'labels' in resp and 'scores' in resp :
167+ labels = resp ['labels' ]
168+ scores = resp ['scores' ]
169+ if len (labels ) == len (scores ):
170+ new_resp = []
171+ for label , score in zip (labels , scores ):
172+ new_resp .append ({"label" : label , "score" : score })
173+ resp = new_resp
174+ else :
175+ raise Exception ("labels and scores do not have the same len, {} != {}" .format (
176+ len (labels ), len (scores )))
177+ else :
178+ raise Exception ("Missing labels or scores key in response dict {}" .format (resp ))
179+ except Exception as e :
180+ logging .logger .warning ("Unable to remap response for api inference compat" )
181+ logging .logger .exception (e )
163182 return resp
164183
165184
0 commit comments