Skip to content

Commit 62c5ae8

Browse files
authored
Update to tools parameter (#1175)
* Update to tools parameter * Use tool_choice argument * Update response to iterate through new tool_calls property * Allow tool_calls to be None * Revert func signature as the local is implicitly optional and adding more typing was not the intention of this change * Fix tests as the response structure looks different when using tool_calls
1 parent bec59be commit 62c5ae8

File tree

4 files changed

+147
-22
lines changed

4 files changed

+147
-22
lines changed

app/backend/approaches/chatapproach.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,17 @@ def get_system_prompt(self, override_prompt: Optional[str], follow_up_questions_
7070

7171
def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
7272
response_message = chat_completion.choices[0].message
73-
if function_call := response_message.function_call:
74-
if function_call.name == "search_sources":
75-
arg = json.loads(function_call.arguments)
76-
search_query = arg.get("search_query", self.NO_RESPONSE)
77-
if search_query != self.NO_RESPONSE:
78-
return search_query
73+
74+
if response_message.tool_calls:
75+
for tool in response_message.tool_calls:
76+
if tool.type != "function":
77+
continue
78+
function = tool.function
79+
if function.name == "search_sources":
80+
arg = json.loads(function.arguments)
81+
search_query = arg.get("search_query", self.NO_RESPONSE)
82+
if search_query != self.NO_RESPONSE:
83+
return search_query
7984
elif query_text := response_message.content:
8085
if query_text.strip() != self.NO_RESPONSE:
8186
return query_text

app/backend/approaches/chatreadretrieveread.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Any, Coroutine, Literal, Optional, Union, overload
1+
from typing import Any, Coroutine, List, Literal, Optional, Union, overload
22

33
from azure.search.documents.aio import SearchClient
44
from azure.search.documents.models import VectorQuery
55
from openai import AsyncOpenAI, AsyncStream
66
from openai.types.chat import (
77
ChatCompletion,
88
ChatCompletionChunk,
9+
ChatCompletionToolParam,
910
)
1011

1112
from approaches.approach import ThoughtStep
@@ -97,19 +98,22 @@ async def run_until_final_call(
9798
original_user_query = history[-1]["content"]
9899
user_query_request = "Generate search query for: " + original_user_query
99100

100-
functions = [
101+
tools: List[ChatCompletionToolParam] = [
101102
{
102-
"name": "search_sources",
103-
"description": "Retrieve sources from the Azure AI Search index",
104-
"parameters": {
105-
"type": "object",
106-
"properties": {
107-
"search_query": {
108-
"type": "string",
109-
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
110-
}
103+
"type": "function",
104+
"function": {
105+
"name": "search_sources",
106+
"description": "Retrieve sources from the Azure AI Search index",
107+
"parameters": {
108+
"type": "object",
109+
"properties": {
110+
"search_query": {
111+
"type": "string",
112+
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
113+
}
114+
},
115+
"required": ["search_query"],
111116
},
112-
"required": ["search_query"],
113117
},
114118
}
115119
]
@@ -131,8 +135,8 @@ async def run_until_final_call(
131135
temperature=0.0,
132136
max_tokens=100, # Setting too low risks malformed JSON, setting too high may affect performance
133137
n=1,
134-
functions=functions,
135-
function_call="auto",
138+
tools=tools,
139+
tool_choice="auto",
136140
)
137141

138142
query_text = self.get_search_query(chat_completion, original_user_query)

tests/test_chatapproach.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,65 @@ def chat_approach():
2424

2525

2626
def test_get_search_query(chat_approach):
27-
payload = '{"id":"chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM","object":"chat.completion","created":1695324963,"model":"gpt-35-turbo","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"choices":[{"index":0,"finish_reason":"function_call","message":{"content":"this is the query","role":"assistant","function_call":{"name":"search_sources","arguments":"{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"}},"content_filter_results":{}}],"usage":{"completion_tokens":19,"prompt_tokens":425,"total_tokens":444}}'
27+
payload = """
28+
{
29+
"id": "chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM",
30+
"object": "chat.completion",
31+
"created": 1695324963,
32+
"model": "gpt-35-turbo",
33+
"prompt_filter_results": [
34+
{
35+
"prompt_index": 0,
36+
"content_filter_results": {
37+
"hate": {
38+
"filtered": false,
39+
"severity": "safe"
40+
},
41+
"self_harm": {
42+
"filtered": false,
43+
"severity": "safe"
44+
},
45+
"sexual": {
46+
"filtered": false,
47+
"severity": "safe"
48+
},
49+
"violence": {
50+
"filtered": false,
51+
"severity": "safe"
52+
}
53+
}
54+
}
55+
],
56+
"choices": [
57+
{
58+
"index": 0,
59+
"finish_reason": "function_call",
60+
"message": {
61+
"content": "this is the query",
62+
"role": "assistant",
63+
"tool_calls": [
64+
{
65+
"id": "search_sources1235",
66+
"type": "function",
67+
"function": {
68+
"name": "search_sources",
69+
"arguments": "{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"
70+
}
71+
}
72+
]
73+
},
74+
"content_filter_results": {
75+
76+
}
77+
}
78+
],
79+
"usage": {
80+
"completion_tokens": 19,
81+
"prompt_tokens": 425,
82+
"total_tokens": 444
83+
}
84+
}
85+
"""
2886
default_query = "hello"
2987
chatcompletions = ChatCompletion.model_validate(json.loads(payload), strict=False)
3088
query = chat_approach.get_search_query(chatcompletions, default_query)

tests/test_chatvisionapproach.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,65 @@ def test_build_filter(chat_approach):
6767

6868

6969
def test_get_search_query(chat_approach):
70-
payload = '{"id":"chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM","object":"chat.completion","created":1695324963,"model":"gpt-4v","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"choices":[{"index":0,"finish_reason":"function_call","message":{"content":"this is the query","role":"assistant","function_call":{"name":"search_sources","arguments":"{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"}},"content_filter_results":{}}],"usage":{"completion_tokens":19,"prompt_tokens":425,"total_tokens":444}}'
70+
payload = """
71+
{
72+
"id": "chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM",
73+
"object": "chat.completion",
74+
"created": 1695324963,
75+
"model": "gpt-35-turbo",
76+
"prompt_filter_results": [
77+
{
78+
"prompt_index": 0,
79+
"content_filter_results": {
80+
"hate": {
81+
"filtered": false,
82+
"severity": "safe"
83+
},
84+
"self_harm": {
85+
"filtered": false,
86+
"severity": "safe"
87+
},
88+
"sexual": {
89+
"filtered": false,
90+
"severity": "safe"
91+
},
92+
"violence": {
93+
"filtered": false,
94+
"severity": "safe"
95+
}
96+
}
97+
}
98+
],
99+
"choices": [
100+
{
101+
"index": 0,
102+
"finish_reason": "function_call",
103+
"message": {
104+
"content": "this is the query",
105+
"role": "assistant",
106+
"tool_calls": [
107+
{
108+
"id": "search_sources1235",
109+
"type": "function",
110+
"function": {
111+
"name": "search_sources",
112+
"arguments": "{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"
113+
}
114+
}
115+
]
116+
},
117+
"content_filter_results": {
118+
119+
}
120+
}
121+
],
122+
"usage": {
123+
"completion_tokens": 19,
124+
"prompt_tokens": 425,
125+
"total_tokens": 444
126+
}
127+
}
128+
"""
71129
default_query = "hello"
72130
chatcompletions = ChatCompletion.model_validate(json.loads(payload), strict=False)
73131
query = chat_approach.get_search_query(chatcompletions, default_query)

0 commit comments

Comments
 (0)