Skip to content

Commit 7b7e6ce

Browse files
srbalak and pamelafox authored
Stabilize search query generation (#652)
* Use function call to stabilize search query * minor changes * sort imports * fix ruff * s * Update app/backend/approaches/chatreadretrieveread.py Co-authored-by: Pamela Fox <[email protected]> * s * blacj format * add test * save --------- Co-authored-by: Pamela Fox <[email protected]>
1 parent 595578c commit 7b7e6ce

File tree

3 files changed

+67
-8
lines changed

3 files changed

+67
-8
lines changed

app/backend/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ async def setup_clients():
193193
if OPENAI_HOST == "azure":
194194
openai.api_type = "azure_ad"
195195
openai.api_base = f"https://{AZURE_OPENAI_SERVICE}.openai.azure.com"
196-
openai.api_version = "2023-05-15"
196+
openai.api_version = "2023-07-01-preview"
197197
openai_token = await azure_credential.get_token("https://cognitiveservices.azure.com/.default")
198198
openai.api_key = openai_token.token
199199
# Store on app.config for later use inside requests

app/backend/approaches/chatreadretrieveread.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, AsyncGenerator
23

34
import openai
@@ -15,6 +16,8 @@ class ChatReadRetrieveReadApproach:
1516
USER = "user"
1617
ASSISTANT = "assistant"
1718

19+
NO_RESPONSE = "0"
20+
1821
"""
1922
Simple retrieve-then-read implementation, using the Cognitive Search and OpenAI APIs directly. It first retrieves
2023
top documents from search, then constructs a prompt with them, and then uses OpenAI to generate an completion
@@ -33,6 +36,7 @@ class ChatReadRetrieveReadApproach:
3336
Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'"""
3437

3538
query_prompt_template = """Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base about employee healthcare plans and the employee handbook.
39+
You have access to Azure Cognitive Search index with 100's of documents.
3640
Generate a search query based on the conversation and the new question.
3741
Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.
3842
Do not include any text inside [] or <<>> in the search query terms.
@@ -78,16 +82,33 @@ async def run_until_final_call(
7882
exclude_category = overrides.get("exclude_category") or None
7983
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
8084

81-
user_q = "Generate search query for: " + history[-1]["user"]
85+
user_query_request = "Generate search query for: " + history[-1]["user"]
86+
87+
functions = [
88+
{
89+
"name": "search_sources",
90+
"description": "Retrieve sources from the Azure Cognitive Search index",
91+
"parameters": {
92+
"type": "object",
93+
"properties": {
94+
"search_query": {
95+
"type": "string",
96+
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
97+
}
98+
},
99+
"required": ["search_query"],
100+
},
101+
}
102+
]
82103

83104
# STEP 1: Generate an optimized keyword search query based on the chat history and the last question
84105
messages = self.get_messages_from_history(
85106
self.query_prompt_template,
86107
self.chatgpt_model,
87108
history,
88-
user_q,
109+
user_query_request,
89110
self.query_prompt_few_shots,
90-
self.chatgpt_token_limit - len(user_q),
111+
self.chatgpt_token_limit - len(user_query_request),
91112
)
92113

93114
chatgpt_args = {"deployment_id": self.chatgpt_deployment} if self.openai_host == "azure" else {}
@@ -98,11 +119,11 @@ async def run_until_final_call(
98119
temperature=0.0,
99120
max_tokens=32,
100121
n=1,
122+
functions=functions,
123+
function_call="auto",
101124
)
102125

103-
query_text = chat_completion.choices[0].message.content
104-
if query_text.strip() == "0":
105-
query_text = history[-1]["user"] # Use the last user input if we failed to generate a better query
126+
query_text = self.get_search_query(chat_completion, history[-1]["user"])
106127

107128
# STEP 2: Retrieve relevant documents from the search index with the GPT optimized query
108129

@@ -186,6 +207,7 @@ async def run_until_final_call(
186207
"thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>"
187208
+ msg_to_display.replace("\n", "<br>"),
188209
}
210+
189211
chat_coroutine = openai.ChatCompletion.acreate(
190212
**chatgpt_args,
191213
model=self.chatgpt_model,
@@ -199,7 +221,8 @@ async def run_until_final_call(
199221

200222
async def run_without_streaming(self, history: list[dict[str, str]], overrides: dict[str, Any]) -> dict[str, Any]:
201223
extra_info, chat_coroutine = await self.run_until_final_call(history, overrides, should_stream=False)
202-
chat_content = (await chat_coroutine).choices[0].message.content
224+
chat_resp = await chat_coroutine
225+
chat_content = chat_resp.choices[0].message.content
203226
extra_info["answer"] = chat_content
204227
return extra_info
205228

@@ -242,3 +265,16 @@ def get_messages_from_history(
242265

243266
messages = message_builder.messages
244267
return messages
268+
269+
def get_search_query(self, chat_completion: dict[str, any], user_query: str):
270+
response_message = chat_completion["choices"][0]["message"]
271+
if function_call := response_message.get("function_call"):
272+
if function_call["name"] == "search_sources":
273+
arg = json.loads(function_call["arguments"])
274+
search_query = arg.get("search_query", self.NO_RESPONSE)
275+
if search_query != self.NO_RESPONSE:
276+
return search_query
277+
elif query_text := response_message.get("content"):
278+
if query_text.strip() != self.NO_RESPONSE:
279+
return query_text
280+
return user_query

tests/test_chatapproach.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import json
2+
3+
from approaches.chatreadretrieveread import ChatReadRetrieveReadApproach
4+
5+
6+
def test_get_search_query():
7+
chat_approach = ChatReadRetrieveReadApproach(None, "", "gpt-35-turbo", "gpt-35-turbo", "", "", "", "")
8+
9+
payload = '{"id":"chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM","object":"chat.completion","created":1695324963,"model":"gpt-35-turbo","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"choices":[{"index":0,"finish_reason":"function_call","message":{"role":"assistant","function_call":{"name":"search_sources","arguments":"{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"}},"content_filter_results":{}}],"usage":{"completion_tokens":19,"prompt_tokens":425,"total_tokens":444}}'
10+
default_query = "hello"
11+
query = chat_approach.get_search_query(json.loads(payload), default_query)
12+
13+
assert query == "accesstelemedicineservices"
14+
15+
16+
def test_get_search_query_returns_default():
17+
chat_approach = ChatReadRetrieveReadApproach(None, "", "gpt-35-turbo", "gpt-35-turbo", "", "", "", "")
18+
19+
payload = '{"id":"chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM","object":"chat.completion","created":1695324963,"model":"gpt-35-turbo","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"choices":[{"index":0,"finish_reason":"function_call","message":{"role":"assistant"},"content_filter_results":{}}],"usage":{"completion_tokens":19,"prompt_tokens":425,"total_tokens":444}}'
20+
default_query = "hello"
21+
query = chat_approach.get_search_query(json.loads(payload), default_query)
22+
23+
assert query == default_query

0 commit comments

Comments (0)