Skip to content

Commit 1144039

Browse files
authored
fix: Set model_max_length in the Tokenizer of DefaultPromptHandler (#5596)
* Set model_max_length in tokenizer in prompt handler * Add release note
1 parent 67da275 commit 1144039

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

haystack/nodes/prompt/invocation_layer/handlers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class DefaultPromptHandler:
     def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.tokenizer.model_max_length = model_max_length
         self.model_max_length = model_max_length
         self.max_length = max_length

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix model_max_length not being set in the Tokenizer in DefaultPromptHandler.

test/prompt/test_handlers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def test_prompt_handler_negative():
     }


+@pytest.mark.unit
+@patch("haystack.nodes.prompt.invocation_layer.handlers.AutoTokenizer.from_pretrained")
+def test_prompt_handler_model_max_length_set_in_tokenizer(mock_tokenizer):
+    prompt_handler = DefaultPromptHandler(model_name_or_path="model_path", model_max_length=10, max_length=3)
+    assert prompt_handler.tokenizer.model_max_length == 10
+
+
 @pytest.mark.integration
 def test_prompt_handler_basics():
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20, max_length=10)
@@ -65,6 +72,9 @@ def test_prompt_handler_basics():
     handler = DefaultPromptHandler(model_name_or_path="gpt2", model_max_length=20)
     assert handler.max_length == 100

+    # test model_max_length is set in tokenizer
+    assert handler.tokenizer.model_max_length == 20
+

 @pytest.mark.integration
 def test_gpt2_prompt_handler():

0 commit comments

Comments
 (0)