|
6 | 6 |
|
7 | 7 | import sseclient |
8 | 8 | import torch |
9 | | -from transformers import (
10 | | -    pipeline,
11 | | -    AutoConfig,
12 | | -    StoppingCriteriaList,
13 | | -    StoppingCriteria,
14 | | -    PreTrainedTokenizer,
15 | | -    PreTrainedTokenizerFast,
16 | | -)
17 | | -from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
| 9 | +from transformers import pipeline, StoppingCriteriaList, StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast |
| 10 | +from transformers.pipelines import get_task |
18 | 11 |
|
19 | 12 | from haystack.errors import OpenAIError |
20 | 13 | from haystack.modeling.utils import initialize_device_settings |
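For context on the new import: `get_task` asks the Hugging Face Hub which pipeline task a model is tagged with, and raises a `RuntimeError` when the name cannot be resolved, which is what the rewritten `supports` below relies on to skip OpenAI model names. A minimal sketch, assuming network access to the Hub; the model names are illustrative, not an exhaustive list:

```python
from transformers.pipelines import get_task

# Ask the Hub which pipeline task each model is tagged with.
print(get_task("google/flan-t5-base"))     # "text2text-generation" (encoder-decoder)
print(get_task("bigscience/bloomz-560m"))  # "text-generation" (decoder-only)

# Names that cannot be resolved on the Hub (e.g. "text-davinci-003")
# raise RuntimeError instead of returning a task string.
```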
@@ -225,7 +218,6 @@ def __init__( |
225 | 218 |         logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
226 | 219 |

227 | 220 |         self.pipe = pipeline(
228 | | -            "text2text-generation",
229 | 221 |             model=model_name_or_path,
230 | 222 |             device=self.devices[0] if "device_map" not in model_input_kwargs else None,
231 | 223 |             use_auth_token=self.use_auth_token,
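Dropping the hardcoded `"text2text-generation"` argument means `pipeline()` now infers the task from the model itself (effectively the same Hub lookup `get_task` performs), so decoder-only `text-generation` models can be loaded alongside encoder-decoder `text2text-generation` models. A rough sketch of that behavior, with an illustrative model name:

```python
from transformers import pipeline

# With no explicit task string, pipeline() resolves the task
# from the model's Hub metadata.
pipe = pipeline(model="google/flan-t5-small")
print(pipe.task)  # "text2text-generation"
print(pipe("Translate English to German: How old are you?"))
```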
@@ -304,22 +296,21 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union |
304 | 296 |
|
305 | 297 |     @classmethod
306 | 298 |     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
| 299 | +        task_name: Optional[str] = None
307 | 300 |         try:
308 | | -            config = AutoConfig.from_pretrained(model_name_or_path)
309 | | -        except OSError:
| 301 | +            task_name = get_task(model_name_or_path)
| 302 | +        except RuntimeError:
310 | 303 |             # This is needed so OpenAI models are skipped over
311 | 304 |             return False
312 | 305 |

313 | | -        if not all(m in model_name_or_path for m in ["flan", "t5"]):
| 306 | +        if not any(m in model_name_or_path for m in ["flan", "mt0", "bloomz"]):
314 | 307 |             logger.warning(
315 | 308 |                 "PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
316 | 309 |                 "Many of the default prompts and PromptTemplates will likely not work as intended. "
317 | 310 |                 "Use custom prompts and PromptTemplates specific to the %s model",
318 | 311 |                 model_name_or_path,
319 | 312 |             )
320 | | -
321 | | -        supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
322 | | -        return config.architectures[0] in supported_models
| 313 | +        return task_name in ["text2text-generation", "text-generation"]
323 | 314 |
|
324 | 315 |
|
325 | 316 | class OpenAIInvocationLayer(PromptModelInvocationLayer):
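Taken together, `supports` now accepts any local model whose Hub task is `text2text-generation` or `text-generation` instead of checking `config.architectures` against the seq2seq mapping, and the name-based warning heuristic switched from requiring both "flan" and "t5" to warning unless the name contains any of "flan", "mt0", or "bloomz". A hypothetical usage sketch; the class name `HFLocalInvocationLayer` and its import path are assumptions based on the surrounding file, not shown in this diff, and the calls require Hub access:

```python
# Hypothetical: class name and import path are assumed, not shown in this diff.
from haystack.nodes.prompt.prompt_node import HFLocalInvocationLayer

HFLocalInvocationLayer.supports("google/flan-t5-base")     # True  -> task "text2text-generation"
HFLocalInvocationLayer.supports("bigscience/bloomz-560m")  # True  -> task "text-generation"
HFLocalInvocationLayer.supports("text-davinci-003")        # False -> get_task raises RuntimeError
```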
|