
Commit 5f1256a

feat: generators (2.0) (#5690)
* add generators module
* add tests for module helper
* reno
* add another test
* move into openai
* improve tests
1 parent 6787ad2 commit 5f1256a

File tree

6 files changed (+68, -0 lines changed)


haystack/preview/components/generators/__init__.py

Whitespace-only changes.

haystack/preview/components/generators/openai/__init__.py

Whitespace-only changes.
haystack/preview/components/generators/openai/_helpers.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
import logging

from haystack.preview.lazy_imports import LazyImport

with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
    import tiktoken


logger = logging.getLogger(__name__)


def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str:
    """
    Ensure that the length of the prompt is within the max tokens limit of the model.
    If needed, truncate the prompt text so that it fits within the limit.

    :param prompt: Prompt text to be sent to the generative model.
    :param tokenizer: The tokenizer used to encode the prompt.
    :param max_tokens_limit: The max tokens limit of the model.
    :return: The prompt text that fits within the max tokens limit of the model.
    """
    tiktoken_import.check()
    tokens = tokenizer.encode(prompt)
    tokens_count = len(tokens)
    if tokens_count > max_tokens_limit:
        logger.warning(
            "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
            "Reduce the length of the prompt to prevent it from being cut off.",
            tokens_count,
            max_tokens_limit,
        )
        prompt = tokenizer.decode(tokens[:max_tokens_limit])
    return prompt
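
For context, a minimal usage sketch (not part of this commit): assuming tiktoken is installed, the helper can be called with a real tiktoken encoding. The encoding name below is an arbitrary choice for illustration, not something the helper mandates.

# Illustrative usage sketch only -- not part of the commit. Assumes `pip install tiktoken`.
import tiktoken

from haystack.preview.components.generators.openai._helpers import enforce_token_limit

# "cl100k_base" is just an example encoding name.
encoding = tiktoken.get_encoding("cl100k_base")

# Returns the prompt unchanged if it fits, otherwise truncates it to 16 tokens and logs a warning.
short_prompt = enforce_token_limit("A very long prompt ...", tokenizer=encoding, max_tokens_limit=16)
print(short_prompt)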
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
preview:
  - Add generators module for LLM generator components.
Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
import pytest

from haystack.preview.components.generators.openai._helpers import enforce_token_limit


@pytest.mark.unit
def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
    assert prompt == "This is a"
    assert caplog.records[0].message == (
        "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
        "limit. Reduce the length of the prompt to prevent it from being cut off."
    )


@pytest.mark.unit
def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
    assert prompt == "This is a test prompt."
    assert not caplog.records

test/preview/conftest.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
from unittest.mock import Mock
import pytest


@pytest.fixture()
def mock_tokenizer():
    """
    Tokenizes the string by splitting on spaces.
    """
    tokenizer = Mock()
    tokenizer.encode = lambda text: text.split()
    tokenizer.decode = lambda tokens: " ".join(tokens)
    return tokenizer
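
As a side note, a standalone sketch (not part of the commit) of why the truncation test above expects "This is a": the mock's "tokens" are whitespace-separated words, so a limit of 3 keeps the first three words.

# Standalone illustration of the mock tokenizer's behaviour -- not part of the commit.
tokens = "This is a test prompt.".split()   # ['This', 'is', 'a', 'test', 'prompt.'] -> 5 "tokens"
truncated = " ".join(tokens[:3])            # mirrors tokenizer.decode(tokens[:max_tokens_limit])
assert truncated == "This is a"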
