Skip to content

Commit 7c4242b

Browse files
committed
Few shots for tools
1 parent f5a5479 commit 7c4242b

File tree

2 files changed

+74
-5
lines changed

2 files changed

+74
-5
lines changed

src/openai_messages_token_helper/message_builder.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ChatCompletionNamedToolChoiceParam,
1111
ChatCompletionRole,
1212
ChatCompletionSystemMessageParam,
13+
ChatCompletionToolMessageParam,
1314
ChatCompletionToolParam,
1415
ChatCompletionUserMessageParam,
1516
)
@@ -48,7 +49,12 @@ def all_messages(self) -> list[ChatCompletionMessageParam]:
4849
return [self.system_message] + self.messages
4950

5051
def insert_message(
51-
self, role: ChatCompletionRole, content: Union[str, Iterable[ChatCompletionContentPartParam]], index: int = 0
52+
self,
53+
role: ChatCompletionRole,
54+
content: Union[str, Iterable[ChatCompletionContentPartParam]],
55+
index: int = 0,
56+
tool_calls: Optional[list[ChatCompletionToolParam]] = None,
57+
tool_call_id: Optional[str] = None,
5258
):
5359
"""
5460
Inserts a message into the conversation at the specified index,
@@ -63,8 +69,14 @@ def insert_message(
6369
message = ChatCompletionUserMessageParam(role="user", content=normalize_content(content))
6470
elif role == "assistant" and isinstance(content, str):
6571
message = ChatCompletionAssistantMessageParam(role="assistant", content=normalize_content(content))
72+
elif role == "assistant" and tool_calls is not None:
73+
message = ChatCompletionAssistantMessageParam(role="assistant", tool_calls=tool_calls)
74+
elif role == "tool" and tool_call_id is not None:
75+
message = ChatCompletionToolMessageParam(
76+
role="tool", tool_call_id=tool_call_id, content=normalize_content(content)
77+
)
6678
else:
67-
raise ValueError(f"Invalid role: {role}")
79+
raise ValueError("Invalid message for builder")
6880
self.messages.insert(index, message)
6981

7082

@@ -102,9 +114,11 @@ def build_messages(
102114
message_builder = _MessageBuilder(system_prompt)
103115

104116
for shot in reversed(few_shots):
105-
if shot["role"] is None or shot["content"] is None:
106-
raise ValueError("Few-shot messages must have both role and content")
107-
message_builder.insert_message(shot["role"], shot["content"])
117+
if shot["role"] is None or (shot.get("content") is None and shot.get("tool_calls") is None):
118+
raise ValueError("Few-shot messages must have role and either content or tool_calls")
119+
message_builder.insert_message(
120+
shot["role"], shot.get("content"), tool_calls=shot.get("tool_calls"), tool_call_id=shot.get("tool_call_id")
121+
)
108122

109123
append_index = len(few_shots)
110124

tests/test_messagebuilder.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,61 @@ def test_messagebuilder_system_fewshots():
200200
assert messages[5]["content"] == user_message_pm["message"]["content"]
201201

202202

203+
def test_messagebuilder_system_fewshotstools():
204+
messages = build_messages(
205+
model="gpt-35-turbo",
206+
system_prompt=system_message_short["message"]["content"],
207+
new_user_content=user_message_pm["message"]["content"],
208+
past_messages=[],
209+
few_shots=[
210+
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
211+
{
212+
"role": "assistant",
213+
"tool_calls": [
214+
{
215+
"id": "call_abc123",
216+
"type": "function",
217+
"function": {
218+
"arguments": '{"search_query":"climbing gear outside"}',
219+
"name": "search_database",
220+
},
221+
}
222+
],
223+
},
224+
{
225+
"role": "tool",
226+
"tool_call_id": "call_abc123",
227+
"content": "Search results for climbing gear that can be used outside: ...",
228+
},
229+
{"role": "user", "content": "are there any shoes less than $50?"},
230+
{
231+
"role": "assistant",
232+
"tool_calls": [
233+
{
234+
"id": "call_abc456",
235+
"type": "function",
236+
"function": {
237+
"arguments": '{"search_query":"shoes","price_filter":{"comparison_operator":"<","value":50}}',
238+
"name": "search_database",
239+
},
240+
}
241+
],
242+
},
243+
{"role": "tool", "tool_call_id": "call_abc456", "content": "Search results for shoes cheaper than 50: ..."},
244+
],
245+
)
246+
# Make sure messages are in the right order
247+
assert messages[0]["role"] == "system"
248+
assert messages[1]["role"] == "user"
249+
assert messages[2]["role"] == "assistant"
250+
assert messages[3]["role"] == "tool"
251+
assert messages[4]["role"] == "user"
252+
assert messages[5]["role"] == "assistant"
253+
assert messages[6]["role"] == "tool"
254+
assert messages[7]["role"] == "user"
255+
assert messages[7]["content"] == user_message_pm["message"]["content"]
256+
257+
203258
def test_messagebuilder_system_tools():
204259
"""Tests that the system message token count is considered."""
205260
messages = build_messages(

0 commit comments

Comments
 (0)