Skip to content

Commit 2ba024a

Browse files
authored
Fix prompt_template override, add tests (#606)
* Remove defaults for getenv * Remove print * missing output * Add tests and fix prepdocs issue * Fix prompt issue, add tests * Fix prompt issue, add tests * Delete unneeded print
1 parent 3c2a266 commit 2ba024a

File tree

4 files changed

+45
-1
lines changed

4 files changed

+45
-1
lines changed

app/backend/approaches/chatreadretrieveread.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async def run_until_final_call(self, history: list[dict[str, str]], overrides: d
131131
# STEP 3: Generate a contextual and content specific answer using the search results and chat history
132132

133133
# Allow client to replace the entire prompt, or to inject into the exiting prompt using >>>
134-
prompt_override = overrides.get("prompt_override")
134+
prompt_override = overrides.get("prompt_template")
135135
if prompt_override is None:
136136
system_message = self.system_message_chat_conversation.format(injected_prompt="", follow_up_questions_prompt=follow_up_questions_prompt)
137137
elif prompt_override.startswith(">>>"):
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"answer": "The capital of France is Paris.",
3+
"data_points": [
4+
"Benefit_Options-2.pdf: There is a whistleblower policy."
5+
],
6+
"thoughts": "Searched for:<br>capital of France<br><br>Conversations:<br>{'role': 'system', 'content': 'You are a cat.'}<br><br>{'role': 'user', 'content': 'What is the capital of France?\\n\\nSources:\\nBenefit_Options-2.pdf: There is a whistleblower policy.'}"
7+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"answer": "The capital of France is Paris.",
3+
"data_points": [
4+
"Benefit_Options-2.pdf: There is a whistleblower policy."
5+
],
6+
"thoughts": "Searched for:<br>capital of France<br><br>Conversations:<br>{'role': 'system', 'content': \"Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.\\nAnswer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.\\nFor tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.\\nEach source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, e.g. [info1.txt]. Don't combine sources, list each source separately, e.g. [info1.txt][info2.pdf].\\n\\n Meow like a cat.\\n\\n\"}<br><br>{'role': 'user', 'content': 'What is the capital of France?\\n\\nSources:\\nBenefit_Options-2.pdf: There is a whistleblower policy.'}"
7+
}

tests/test_app.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,36 @@ async def test_chat_text_semanticcaptions(client, snapshot):
154154
snapshot.assert_match(json.dumps(result, indent=4), "result.json")
155155

156156

157+
@pytest.mark.asyncio
158+
async def test_chat_prompt_template(client, snapshot):
159+
response = await client.post(
160+
"/chat",
161+
json={
162+
"approach": "rrr",
163+
"history": [{"user": "What is the capital of France?"}],
164+
"overrides": {"retrieval_mode": "text", "prompt_template": "You are a cat."},
165+
},
166+
)
167+
assert response.status_code == 200
168+
result = await response.get_json()
169+
snapshot.assert_match(json.dumps(result, indent=4), "result.json")
170+
171+
172+
@pytest.mark.asyncio
173+
async def test_chat_prompt_template_concat(client, snapshot):
174+
response = await client.post(
175+
"/chat",
176+
json={
177+
"approach": "rrr",
178+
"history": [{"user": "What is the capital of France?"}],
179+
"overrides": {"retrieval_mode": "text", "prompt_template": ">>> Meow like a cat."},
180+
},
181+
)
182+
assert response.status_code == 200
183+
result = await response.get_json()
184+
snapshot.assert_match(json.dumps(result, indent=4), "result.json")
185+
186+
157187
@pytest.mark.asyncio
158188
async def test_chat_hybrid(client, snapshot):
159189
response = await client.post(

0 commit comments

Comments
 (0)