Skip to content

Commit d9ece3d

Browse files
committed
Further simplify things
1 parent d048366 commit d9ece3d

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

haystack/nodes/prompt/prompt_node.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99

1010
from haystack import MultiLabel
1111
from haystack.nodes.base import BaseComponent
12-
from haystack.nodes.prompt.providers import PromptModelInvocationLayer
12+
from haystack.nodes.prompt.providers import PromptModelInvocationLayer, instruction_following_models
13+
1314
from haystack.schema import Document
1415
from haystack.telemetry_2 import send_event
1516

@@ -215,6 +216,14 @@ def __init__(
215216

216217
self.model_kwargs = model_kwargs if model_kwargs else {}
217218
self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)
219+
is_instruction_following: bool = any(m in model_name_or_path for m in instruction_following_models())
220+
if not is_instruction_following:
221+
logger.warning(
222+
"PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
223+
"Many of the default prompts and PromptTemplates will likely not work as intended. "
224+
"Use custom prompts and PromptTemplates specific to the %s model",
225+
model_name_or_path,
226+
)
218227

219228
def create_invocation_layer(
220229
self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]

haystack/nodes/prompt/providers.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,10 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
116116
pass
117117

118118

119+
def instruction_following_models() -> List[str]:
120+
return ["flan", "mt0", "bloomz", "davinci"]
121+
122+
119123
class StopWordsCriteria(StoppingCriteria):
120124
"""
121125
Stops text generation if any one of the stop words is generated.
@@ -300,16 +304,9 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
300304
try:
301305
task_name = get_task(model_name_or_path)
302306
except RuntimeError:
303-
# This is needed so OpenAI models are skipped over
307+
# This will fail for all non-HF models
304308
return False
305309

306-
if not any(m in model_name_or_path for m in ["flan", "mt0", "bloomz"]):
307-
logger.warning(
308-
"PromptNode has been potentially initialized with a language model not fine-tuned on instruction following tasks. "
309-
"Many of the default prompts and PromptTemplates will likely not work as intended. "
310-
"Use custom prompts and PromptTemplates specific to the %s model",
311-
model_name_or_path,
312-
)
313310
return task_name in ["text2text-generation", "text-generation"]
314311

315312

0 commit comments

Comments (0)