Skip to content

Commit d048366

Browse files
committed
Proper HF pipeline task type detection, new supports impl
1 parent 3272e2b commit d048366

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

haystack/nodes/prompt/providers.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,8 @@
66

77
import sseclient
88
import torch
9-
from transformers import (
10-
pipeline,
11-
AutoConfig,
12-
StoppingCriteriaList,
13-
StoppingCriteria,
14-
PreTrainedTokenizer,
15-
PreTrainedTokenizerFast,
16-
)
17-
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
9+
from transformers import pipeline, StoppingCriteriaList, StoppingCriteria, PreTrainedTokenizer, PreTrainedTokenizerFast
10+
from transformers.pipelines import get_task
1811

1912
from haystack.errors import OpenAIError
2013
from haystack.modeling.utils import initialize_device_settings
@@ -225,7 +218,6 @@ def __init__(
225218
logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
226219

227220
self.pipe = pipeline(
228-
"text2text-generation",
229221
model=model_name_or_path,
230222
device=self.devices[0] if "device_map" not in model_input_kwargs else None,
231223
use_auth_token=self.use_auth_token,
@@ -304,22 +296,21 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
304296

305297
@classmethod
306298
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
299+
task_name: Optional[str] = None
307300
try:
308-
config = AutoConfig.from_pretrained(model_name_or_path)
309-
except OSError:
301+
task_name = get_task(model_name_or_path)
302+
except RuntimeError:
310303
# This is needed so OpenAI models are skipped over
311304
return False
312305

313-
if not all(m in model_name_or_path for m in ["flan", "t5"]):
306+
if not any(m in model_name_or_path for m in ["flan", "mt0", "bloomz"]):
314307
logger.warning(
315308
"PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
316309
"Many of the default prompts and PromptTemplates will likely not work as intended. "
317310
"Use custom prompts and PromptTemplates specific to the %s model",
318311
model_name_or_path,
319312
)
320-
321-
supported_models = list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.values())
322-
return config.architectures[0] in supported_models
313+
return task_name in ["text2text-generation", "text-generation"]
323314

324315

325316
class OpenAIInvocationLayer(PromptModelInvocationLayer):

0 commit comments

Comments (0)