@@ -117,7 +117,7 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
 
 
 def instruction_following_models() -> List[str]:
-    return ["flan", "mt0", "bloomz", "davinci"]
+    return ["flan", "mt0", "bloomz", "davinci", "opt-iml"]
 
 
 class StopWordsCriteria(StoppingCriteria):
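Entries in this list act as substrings matched against full model names, which is why the short token "opt-iml" covers all OPT-IML variants. A minimal sketch of such a check (the caller below is hypothetical, not part of this diff):

# Hypothetical usage: substring match against a full model name.
model_name_or_path = "facebook/opt-iml-max-1.3b"
follows_instructions = any(m in model_name_or_path for m in instruction_following_models())
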
@@ -220,7 +220,7 @@ def __init__(
 
         if len(model_input_kwargs) > 0:
             logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
-
+        self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
         self.pipe = pipeline(
             model=model_name_or_path,
             device=self.devices[0] if "device_map" not in model_input_kwargs else None,
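get_task comes from transformers.pipelines and resolves the model's pipeline tag from the Hugging Face Hub; forwarding use_auth_token lets the lookup succeed for private or gated models. A minimal sketch (the model name is illustrative):

from transformers.pipelines import get_task

# Queries the Hub for the pipeline tag recorded on the model card,
# e.g. "text2text-generation" for FLAN-T5 models.
task_name = get_task("google/flan-t5-base", use_auth_token=None)
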
@@ -237,23 +237,34 @@ def invoke(self, *args, **kwargs):
         It takes a prompt and returns a list of generated text using the local Hugging Face transformers model
         :return: A list of generated text.
 
-        Note: Only kwargs relevant to Text2TextGenerationPipeline are passed to Hugging Face as model_input_kwargs.
-        Other kwargs are ignored.
+        Note: Only kwargs relevant to Text2TextGenerationPipeline and TextGenerationPipeline are passed to
+        Hugging Face as model_input_kwargs. Other kwargs are ignored.
         """
         output: List[Dict[str, str]] = []
         stop_words = kwargs.pop("stop_words", None)
         top_k = kwargs.pop("top_k", None)
         if kwargs and "prompt" in kwargs:
             prompt = kwargs.pop("prompt")
 
-            # Consider only Text2TextGenerationPipeline relevant, ignore others
-            # For more details refer to Hugging Face Text2TextGenerationPipeline documentation
+            # Consider only Text2TextGenerationPipeline and TextGenerationPipeline relevant, ignore others
+            # For more details refer to the Hugging Face Text2TextGenerationPipeline and TextGenerationPipeline
+            # documentation
             # TODO resolve these kwargs from the pipeline signature
             model_input_kwargs = {
                 key: kwargs[key]
-                for key in ["return_tensors", "return_text", "clean_up_tokenization_spaces", "truncation"]
+                for key in [
+                    "return_tensors",
+                    "return_text",
+                    "return_full_text",
+                    "clean_up_tokenization_spaces",
+                    "truncation",
+                ]
                 if key in kwargs
             }
+            # Prefer return_full_text=False for text-generation (unless explicitly set)
+            # so that only the generated text is returned (excluding the prompt)
+            if self.task_name == "text-generation" and "return_full_text" not in model_input_kwargs:
+                model_input_kwargs["return_full_text"] = False
             if stop_words:
                 sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words)
                 model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
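For context on the new default: with return_full_text=False, TextGenerationPipeline returns only the completion instead of prompt + completion, matching how text2text-generation models already behave. A minimal sketch (the model choice is illustrative):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
# return_full_text=False strips the prompt from "generated_text";
# by default the prompt would be echoed back at the start.
result = generator("Hello, my name is", max_new_tokens=10, return_full_text=False)
print(result[0]["generated_text"])
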
@@ -302,7 +313,7 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
         task_name: Optional[str] = None
         try:
-            task_name = get_task(model_name_or_path)
+            task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
         except RuntimeError:
             # This will fail for all non-HF models
             return False
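With the token forwarded, the support check also works for private models; get_task raises RuntimeError for names that don't resolve on the Hub, which is what routes non-HF models to return False. A sketch, assuming the enclosing class is the local Hugging Face invocation layer (its name does not appear in these hunks):

# Class name assumed from context; substitute the class this diff modifies.
HFLocalInvocationLayer.supports("my-org/private-flan-t5", use_auth_token="hf_xxx")  # True if the Hub task matches
HFLocalInvocationLayer.supports("text-davinci-003")  # False: not resolvable on the Hub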