
Commit 2bb618c (1 parent: aca5654)

Return only generated text for text-generation, update docs, improve unit tests

File tree: 3 files changed (+31, -10 lines)

haystack/nodes/prompt/prompt_node.py (3 additions, 2 deletions)

@@ -165,7 +165,7 @@ class PromptModel(BaseComponent):
     """
     The PromptModel class is a component that uses a pre-trained model to perform tasks based on a prompt. Out of
     the box, it supports model invocation layers for:
-    - Hugging Face transformers (all text2text-generation models)
+    - Hugging Face transformers (all text2text-generation and text-generation models)
     - OpenAI InstructGPT models
     - Azure OpenAI InstructGPT models

@@ -400,7 +400,8 @@ class PromptNode(BaseComponent):
     additional custom model invocation layers.

     We recommend using LLMs fine-tuned on a collection of datasets phrased as instructions, otherwise we find that the
-    LLM does not "follow" prompt instructions well. This is why we recommend using T5 flan or OpenAI InstructGPT models.
+    LLM does not "follow" prompt instructions well. The list of instruction-following models grows every month;
+    the current list includes Flan, OpenAI InstructGPT, opt-iml, bloomz, and mt0 models.

     For more details, see [PromptNode](https://docs.haystack.deepset.ai/docs/prompt_node).
     """

haystack/nodes/prompt/providers.py (19 additions, 8 deletions)

@@ -117,7 +117,7 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union


 def instruction_following_models() -> List[str]:
-    return ["flan", "mt0", "bloomz", "davinci"]
+    return ["flan", "mt0", "bloomz", "davinci", "opt-iml"]


 class StopWordsCriteria(StoppingCriteria):
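The entries returned here read as substrings to be matched against model names (for example, "flan" would match "google/flan-t5-xl"). A hypothetical helper illustrating that use; it is not part of this diff:

def is_instruction_following(model_name_or_path: str) -> bool:
    # Hypothetical: treat each entry as a substring of the model name.
    return any(m in model_name_or_path for m in instruction_following_models())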
@@ -220,7 +220,7 @@ def __init__(

         if len(model_input_kwargs) > 0:
             logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
-
+        self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)
         self.pipe = pipeline(
             model=model_name_or_path,
             device=self.devices[0] if "device_map" not in model_input_kwargs else None,
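get_task resolves a model's pipeline tag via the Hugging Face Hub; caching it on self.task_name at construction time lets invoke() branch on it later. A quick illustration (requires Hub access; model names are examples):

from transformers.pipelines import get_task

print(get_task("gpt2"))                 # text-generation
print(get_task("google/flan-t5-base"))  # text2text-generation

# For private or gated models, pass a token:
# get_task("my-org/private-model", use_auth_token="hf_...")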
@@ -237,23 +237,34 @@ def invoke(self, *args, **kwargs):
         It takes a prompt and returns a list of generated text using the local Hugging Face transformers model
         :return: A list of generated text.

-        Note: Only kwargs relevant to Text2TextGenerationPipeline are passed to Hugging Face as model_input_kwargs.
-        Other kwargs are ignored.
+        Note: Only kwargs relevant to Text2TextGenerationPipeline and TextGenerationPipeline are passed to
+        Hugging Face as model_input_kwargs. Other kwargs are ignored.
         """
         output: List[Dict[str, str]] = []
         stop_words = kwargs.pop("stop_words", None)
         top_k = kwargs.pop("top_k", None)
         if kwargs and "prompt" in kwargs:
             prompt = kwargs.pop("prompt")

-            # Consider only Text2TextGenerationPipeline relevant, ignore others
-            # For more details refer to Hugging Face Text2TextGenerationPipeline documentation
+            # Consider only Text2TextGenerationPipeline and TextGenerationPipeline relevant, ignore others
+            # For more details, refer to the Hugging Face Text2TextGenerationPipeline and TextGenerationPipeline
+            # documentation
             # TODO resolve these kwargs from the pipeline signature
             model_input_kwargs = {
                 key: kwargs[key]
-                for key in ["return_tensors", "return_text", "clean_up_tokenization_spaces", "truncation"]
+                for key in [
+                    "return_tensors",
+                    "return_text",
+                    "return_full_text",
+                    "clean_up_tokenization_spaces",
+                    "truncation",
+                ]
                 if key in kwargs
             }
+            # Default return_full_text to False for text-generation (unless explicitly set),
+            # so that only the generated text is returned (excluding the prompt)
+            if "text-generation" == self.task_name and "return_full_text" not in model_input_kwargs:
+                model_input_kwargs["return_full_text"] = False
             if stop_words:
                 sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words)
                 model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw])
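The forced default mirrors the return_full_text flag of the Hugging Face TextGenerationPipeline; a standalone sketch of the difference (tiny model chosen arbitrarily):

from transformers import pipeline

generator = pipeline("text-generation", model="bigscience/bigscience-small-testing")
prompt = "Hello big science!"

# transformers' own default (return_full_text=True) prepends the prompt:
full = generator(prompt, return_full_text=True, max_new_tokens=5)

# The value PromptNode now requests unless the caller overrides it:
only_new = generator(prompt, return_full_text=False, max_new_tokens=5)

print(full[0]["generated_text"])      # starts with the prompt
print(only_new[0]["generated_text"])  # continuation only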
@@ -302,7 +313,7 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
         task_name: Optional[str] = None
         try:
-            task_name = get_task(model_name_or_path)
+            task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
         except RuntimeError:
             # This will fail for all non-HF models
             return False
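With the token now forwarded, support checks on private or gated models can succeed; a hypothetical call (the class name HFLocalInvocationLayer is assumed from the surrounding file, and the model name is invented):

# Assumed: supports() is a classmethod on the invocation layer this diff modifies.
HFLocalInvocationLayer.supports("my-org/private-model", use_auth_token="hf_...")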

test/prompt/test_prompt_node.py (9 additions, 0 deletions)

@@ -409,10 +409,19 @@ def test_streaming_prompt_node():


 def test_prompt_node_with_text_generation_model():
+    # Test simple prompting with a text-generation model.
+    # By default, we force the model not to return the prompt text,
+    # so text-generation models can be used with PromptNode
+    # just like text2text-generation models.
     node = PromptNode("bigscience/bigscience-small-testing")
     r = node("Hello big science!")
     assert len(r[0]) > 0

+    # Test prompting with the parameter to return the prompt text as well;
+    # users can use it to get both the prompt and the generated text.
+    r = node("Hello big science!", return_full_text=True)
+    assert len(r[0]) > 0 and r[0].startswith("Hello big science!")
+

 @pytest.mark.integration
 @pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
