Skip to content

Commit 2a360cd

Browse files
committed
Pass max_new_tokens and not max_length for text-generation models
1 parent 66ce061 commit 2a360cd

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

haystack/nodes/prompt/providers.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -261,9 +261,10 @@ def invoke(self, *args, **kwargs):
261261
]
262262
if key in kwargs
263263
}
264+
is_text_generation = "text-generation" == self.task_name
264265
# Prefer return_full_text is False for text-generation (unless explicitly set)
265266
# Thus only generated text is returned (excluding prompt)
266-
if "text-generation" == self.task_name and "return_full_text" not in model_input_kwargs:
267+
if is_text_generation and "return_full_text" not in model_input_kwargs:
267268
model_input_kwargs["return_full_text"] = False
268269
model_input_kwargs["max_new_tokens"] = self.max_length
269270
if stop_words:
@@ -272,7 +273,13 @@ def invoke(self, *args, **kwargs):
272273
if top_k:
273274
model_input_kwargs["num_return_sequences"] = top_k
274275
model_input_kwargs["num_beams"] = top_k
275-
output = self.pipe(prompt, max_length=self.max_length, **model_input_kwargs)
276+
# max_new_tokens is used for text-generation and max_length for text2text-generation
277+
if is_text_generation:
278+
model_input_kwargs["max_new_tokens"] = self.max_length
279+
else:
280+
model_input_kwargs["max_length"] = self.max_length
281+
282+
output = self.pipe(prompt, **model_input_kwargs)
276283
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
277284

278285
if stop_words:

0 commit comments

Comments (0)