@@ -102,22 +102,18 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
102102 "or `candidateLabels`."
103103 )
104104
105- if api_inference_compat () and self .pipeline .task in ["text-classification" , "token-classification" ] and \
106- isinstance (inputs , str ):
107- inputs = [inputs ]
108- if self .pipeline .task == "text-classification" and "top_k" not in parameters :
109- top_k = os .environ .get ("DEFAULT_TOP_K" , 5 )
110- parameters ["top_k" ] = top_k
111- resp = self .pipeline (inputs , ** parameters )
112- # # We don't want to return {}
113- if isinstance (resp , list ) and len (resp ) > 0 :
114- if not isinstance (resp [0 ], list ):
115- return [resp ]
116- else :
117- return resp
118- else :
105+ if api_inference_compat ():
106+ if self .pipeline .task == "text-classification" and isinstance (inputs , str ):
107+ inputs = [inputs ]
108+ parameters .setdefault ("top_k" , os .environ .get ("DEFAULT_TOP_K" , 5 ))
109+ resp = self .pipeline (inputs , ** parameters )
110+ # # We don't want to return {}
111+ if isinstance (resp , list ) and len (resp ) > 0 :
112+ if not isinstance (resp [0 ], list ):
113+ return [resp ]
119114 return resp
120-
115+ if self .pipeline .task == "token-classification" :
116+ parameters .setdefault ("aggregation_strategy" , os .environ .get ("DEFAULT_AGGREGATION_STRATEGY" , "simple" ))
121117 return (
122118 self .pipeline (** inputs , ** parameters ) if isinstance (inputs , dict ) else self .pipeline (inputs , ** parameters ) # type: ignore
123119 )
0 commit comments