Commit e794248

Merge pull request road-core#309 from romartin/AAP-38368-2
AAP-38368: Fix sample prompt token calculation, which may otherwise cause issues in the subsequent available-token calculation
2 parents 505ad72 + 4891926 commit e794248

File tree

5 files changed: +44 -10 lines changed

ols/src/prompts/prompt_generator.py
ols/src/query_helpers/docs_summarizer.py
ols/utils/token_handler.py
tests/unit/prompts/test_prompt_generator.py
tests/unit/query_helpers/test_docs_summarizer.py

ols/src/prompts/prompt_generator.py

Lines changed: 5 additions & 0 deletions

@@ -13,6 +13,11 @@
 from ols.customize import prompts


+def restructure_rag_context(text: str, model: str) -> str:
+    """Restructure rag text by appending special characters."""
+    return restructure_rag_context_post(restructure_rag_context_pre(text, model), model)
+
+
 def restructure_rag_context_pre(text: str, model: str) -> str:
     """Restructure rag text - pre truncation."""
     if ModelFamily.GRANITE in model:
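
The new helper is a thin wrapper composing the existing pre- and post-truncation steps, so callers no longer chain them by hand. A minimal self-contained sketch of that composition; the stand-in pre/post bodies below are hypothetical (the real ones apply Granite-specific markup inside prompt_generator.py):

# Sketch only: the pre/post bodies here are hypothetical stand-ins;
# the real implementations live in ols/src/prompts/prompt_generator.py.
GRANITE = "granite"  # assumption: stands in for ols.constants.ModelFamily.GRANITE

def restructure_rag_context_pre(text: str, model: str) -> str:
    """Pre-truncation step (hypothetical markup)."""
    return f"[Document]\n{text}" if GRANITE in model else text

def restructure_rag_context_post(text: str, model: str) -> str:
    """Post-truncation step (hypothetical markup)."""
    return f"{text}\n[End]" if GRANITE in model else text

def restructure_rag_context(text: str, model: str) -> str:
    """Restructure rag text by appending special characters."""
    return restructure_rag_context_post(restructure_rag_context_pre(text, model), model)

print(restructure_rag_context("some chunk", "granite-13b-chat-v2"))
# [Document]
# some chunk
# [End]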

ols/src/query_helpers/docs_summarizer.py

Lines changed: 11 additions & 2 deletions

@@ -12,7 +12,11 @@
 from ols.app.models.models import RagChunk, SummarizerResponse
 from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters
 from ols.customize import prompts, reranker
-from ols.src.prompts.prompt_generator import GeneratePrompt
+from ols.src.prompts.prompt_generator import (
+    GeneratePrompt,
+    restructure_history,
+    restructure_rag_context,
+)
 from ols.src.query_helpers.query_helper import QueryHelper
 from ols.utils.token_handler import TokenHandler

@@ -80,7 +84,12 @@ def _prepare_prompt(
         # Use sample text for context/history to get complete prompt
         # instruction. This is used to calculate available tokens.
         temp_prompt, temp_prompt_input = GeneratePrompt(
-            query, ["sample"], ["ai: sample"], self.system_prompt
+            # The sample prompt's context/history must be restructured for the
+            # given model so the subsequent available-token calculation is right.
+            query,
+            [restructure_rag_context("sample", self.model)],
+            [restructure_history("ai: sample", self.model)],
+            self._system_prompt,
         ).generate_prompt(self.model)

         available_tokens = token_handler.calculate_and_check_available_tokens(
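
This is the substantive fix. The throwaway prompt used only for token accounting previously embedded the raw strings "sample" and "ai: sample", while the real prompt embeds model-specific, restructured context and history. The raw sample therefore undercounted the fixed prompt overhead and inflated the available-token budget. A toy illustration of the effect, assuming a whitespace tokenizer and invented markup (neither is the repo's actual behavior):

# Toy illustration: whitespace "tokens" and hypothetical markup, not the
# repo's real TokenHandler or restructuring rules.
def count_tokens(text: str) -> int:
    return len(text.split())

context_window_size = 50
max_tokens_for_response = 20

plain_sample = "sample"
decorated_sample = "[Document] sample [End]"  # hypothetical model-specific markup

for sample in (plain_sample, decorated_sample):
    prompt_token_count = count_tokens(f"system instructions ... {sample} ...")
    available_tokens = context_window_size - max_tokens_for_response - prompt_token_count
    print(f"{sample!r}: available_tokens={available_tokens}")
# 'sample': available_tokens=25
# '[Document] sample [End]': available_tokens=23

The decorated sample yields the smaller, more realistic budget, which is why _prepare_prompt now restructures "sample" and "ai: sample" before counting.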

ols/utils/token_handler.py

Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ def calculate_and_check_available_tokens(
             context_window_size - max_tokens_for_response - prompt_token_count
         )

-        if available_tokens <= 0:
+        if available_tokens < 0:
             limit = context_window_size - max_tokens_for_response
             raise PromptTooLongError(
                 f"Prompt length {prompt_token_count} exceeds LLM "

tests/unit/prompts/test_prompt_generator.py

Lines changed: 2 additions & 6 deletions

@@ -17,8 +17,7 @@
 from ols.src.prompts.prompt_generator import (
     GeneratePrompt,
     restructure_history,
-    restructure_rag_context_post,
-    restructure_rag_context_pre,
+    restructure_rag_context,
 )

 model = [GRANITE_13B_CHAT_V2, GPT35_TURBO]

@@ -33,10 +32,7 @@

 def _restructure_prompt_input(rag_context, conversation_history, model):
     """Restructure prompt input."""
-    rag_formatted = [
-        restructure_rag_context_post(restructure_rag_context_pre(text, model), model)
-        for text in rag_context
-    ]
+    rag_formatted = [restructure_rag_context(text, model) for text in rag_context]
     history_formatted = [
         restructure_history(history, model) for history in conversation_history
     ]
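
The test helper now calls the public wrapper instead of re-chaining the two internal steps; the equivalence it relies on is exactly what the wrapper guarantees by construction. A hypothetical spot-check (not a test that exists in the repo), using illustrative values:

from ols.src.prompts.prompt_generator import (
    restructure_rag_context,
    restructure_rag_context_post,
    restructure_rag_context_pre,
)

text, model = "some chunk", "granite-13b-chat-v2"  # illustrative values
assert restructure_rag_context(text, model) == restructure_rag_context_post(
    restructure_rag_context_pre(text, model), model
)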

tests/unit/query_helpers/test_docs_summarizer.py

Lines changed: 25 additions & 1 deletion

@@ -1,7 +1,7 @@
 """Unit tests for DocsSummarizer class."""

 import logging
-from unittest.mock import ANY, patch
+from unittest.mock import ANY, call, patch

 import pytest

@@ -121,6 +121,30 @@ def test_summarize_truncation():
     assert summary.history_truncated


+@patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4)
+@patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
+def test_prepare_prompt_context():
+    """Basic test for DocsSummarizer checking re-structuring of context for the 'temp' prompt."""
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
+    question = "What's the ultimate question with answer 42?"
+    history = ["human: What is Kubernetes?"]
+    rag_index = MockLlamaIndex()
+
+    with patch(
+        "ols.src.query_helpers.docs_summarizer.restructure_rag_context",
+        return_value="patched_context",
+    ) as restructure_rag_context:
+        summarizer.create_response(question, rag_index, history)
+        restructure_rag_context.assert_has_calls([call("sample", ANY)])
+
+    with patch(
+        "ols.src.query_helpers.docs_summarizer.restructure_history",
+        return_value="patched_history",
+    ) as restructure_history:
+        summarizer.create_response(question, rag_index, history)
+        restructure_history.assert_has_calls([call("ai: sample", ANY)])
+
+
 @patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
 def test_summarize_no_reference_content():
     """Basic test for DocsSummarizer using mocked index and query engine."""
