Skip to content

Commit 71f9bfe

Browse files
authored
Merge pull request #14 from pamelafox/fewshottols
Support tools in few shots
2 parents 0d7b9e7 + 1d27d9a commit 71f9bfe

File tree

6 files changed

+106
-14
lines changed

6 files changed

+106
-14
lines changed

.github/workflows/python.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ jobs:
3030
run: black . --check --verbose
3131
- name: Run unit tests
3232
run: |
33-
python3 -m pytest -s -vv --cov --cov-fail-under=98
33+
python3 -m pytest -s -vv --cov --cov-fail-under=97
3434
- name: Run type checks
3535
run: mypy .

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.1.7] - Aug 3, 2024
6+
7+
- Fix bug where you couldn't pass in example tool calls in `few_shots` to `build_messages`.
8+
9+
## [0.1.6] - Aug 2, 2024
10+
11+
- Fix bug where you couldn't pass in `tools` and `default_to_cl100k` to True with a non-OpenAI model.
12+
513
## [0.1.5] - June 4, 2024
614

715
- Remove spurious `print` call when counting tokens for function calling.

CONTRIBUTING.md

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@ python3 -m pytest
1717

1818
## Publishing
1919

20-
Publish to PyPi:
20+
1. Update the CHANGELOG with description of changes
2121

22-
```shell
23-
export FLIT_USERNAME=__token__
24-
export FLIT_PASSWORD=<your-pypi-token>
25-
flit publish
26-
```
22+
2. Update the version number in pyproject.toml
23+
24+
3. Push the changes to the main branch
25+
26+
4. Publish to PyPi:
27+
28+
```shell
29+
export FLIT_USERNAME=__token__
30+
export FLIT_PASSWORD=<your-pypi-token>
31+
flit publish
32+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "openai-messages-token-helper"
33
description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API."
4-
version = "0.1.5"
4+
version = "0.1.7"
55
authors = [{name = "Pamela Fox"}]
66
requires-python = ">=3.9"
77
readme = "README.md"

src/openai_messages_token_helper/message_builder.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,21 @@
77
ChatCompletionAssistantMessageParam,
88
ChatCompletionContentPartParam,
99
ChatCompletionMessageParam,
10+
ChatCompletionMessageToolCallParam,
1011
ChatCompletionNamedToolChoiceParam,
1112
ChatCompletionRole,
1213
ChatCompletionSystemMessageParam,
14+
ChatCompletionToolMessageParam,
1315
ChatCompletionToolParam,
1416
ChatCompletionUserMessageParam,
1517
)
1618

1719
from .model_helper import count_tokens_for_message, count_tokens_for_system_and_tools, get_token_limit
1820

1921

20-
def normalize_content(content: Union[str, Iterable[ChatCompletionContentPartParam]]):
22+
def normalize_content(content: Union[str, Iterable[ChatCompletionContentPartParam], None]):
23+
if content is None:
24+
return None
2125
if isinstance(content, str):
2226
return unicodedata.normalize("NFC", content)
2327
else:
@@ -48,7 +52,12 @@ def all_messages(self) -> list[ChatCompletionMessageParam]:
4852
return [self.system_message] + self.messages
4953

5054
def insert_message(
51-
self, role: ChatCompletionRole, content: Union[str, Iterable[ChatCompletionContentPartParam]], index: int = 0
55+
self,
56+
role: ChatCompletionRole,
57+
content: Union[str, Iterable[ChatCompletionContentPartParam], None],
58+
index: int = 0,
59+
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] = None,
60+
tool_call_id: Optional[str] = None,
5261
):
5362
"""
5463
Inserts a message into the conversation at the specified index,
@@ -63,8 +72,14 @@ def insert_message(
6372
message = ChatCompletionUserMessageParam(role="user", content=normalize_content(content))
6473
elif role == "assistant" and isinstance(content, str):
6574
message = ChatCompletionAssistantMessageParam(role="assistant", content=normalize_content(content))
75+
elif role == "assistant" and tool_calls is not None:
76+
message = ChatCompletionAssistantMessageParam(role="assistant", tool_calls=tool_calls)
77+
elif role == "tool" and tool_call_id is not None:
78+
message = ChatCompletionToolMessageParam(
79+
role="tool", tool_call_id=tool_call_id, content=normalize_content(content)
80+
)
6681
else:
67-
raise ValueError(f"Invalid role: {role}")
82+
raise ValueError("Invalid message for builder")
6883
self.messages.insert(index, message)
6984

7085

@@ -102,9 +117,17 @@ def build_messages(
102117
message_builder = _MessageBuilder(system_prompt)
103118

104119
for shot in reversed(few_shots):
105-
if shot["role"] is None or shot["content"] is None:
106-
raise ValueError("Few-shot messages must have both role and content")
107-
message_builder.insert_message(shot["role"], shot["content"])
120+
if shot["role"] is None or (shot.get("content") is None and shot.get("tool_calls") is None):
121+
raise ValueError("Few-shot messages must have role and either content or tool_calls")
122+
tool_call_id = shot.get("tool_call_id")
123+
if tool_call_id is not None and not isinstance(tool_call_id, str):
124+
raise ValueError("tool_call_id must be a string value")
125+
tool_calls = shot.get("tool_calls")
126+
if tool_calls is not None and not isinstance(tool_calls, Iterable):
127+
raise ValueError("tool_calls must be a list of tool calls")
128+
message_builder.insert_message(
129+
shot["role"], shot.get("content"), tool_calls=tool_calls, tool_call_id=tool_call_id
130+
)
108131

109132
append_index = len(few_shots)
110133

tests/test_messagebuilder.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,61 @@ def test_messagebuilder_system_fewshots():
200200
assert messages[5]["content"] == user_message_pm["message"]["content"]
201201

202202

203+
def test_messagebuilder_system_fewshotstools():
204+
messages = build_messages(
205+
model="gpt-35-turbo",
206+
system_prompt=system_message_short["message"]["content"],
207+
new_user_content=user_message_pm["message"]["content"],
208+
past_messages=[],
209+
few_shots=[
210+
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
211+
{
212+
"role": "assistant",
213+
"tool_calls": [
214+
{
215+
"id": "call_abc123",
216+
"type": "function",
217+
"function": {
218+
"arguments": '{"search_query":"climbing gear outside"}',
219+
"name": "search_database",
220+
},
221+
}
222+
],
223+
},
224+
{
225+
"role": "tool",
226+
"tool_call_id": "call_abc123",
227+
"content": "Search results for climbing gear that can be used outside: ...",
228+
},
229+
{"role": "user", "content": "are there any shoes less than $50?"},
230+
{
231+
"role": "assistant",
232+
"tool_calls": [
233+
{
234+
"id": "call_abc456",
235+
"type": "function",
236+
"function": {
237+
"arguments": '{"search_query":"shoes","price_filter":{"comparison_operator":"<","value":50}}',
238+
"name": "search_database",
239+
},
240+
}
241+
],
242+
},
243+
{"role": "tool", "tool_call_id": "call_abc456", "content": "Search results for shoes cheaper than 50: ..."},
244+
],
245+
)
246+
# Make sure messages are in the right order
247+
assert messages[0]["role"] == "system"
248+
assert messages[1]["role"] == "user"
249+
assert messages[2]["role"] == "assistant"
250+
assert messages[3]["role"] == "tool"
251+
assert messages[4]["role"] == "user"
252+
assert messages[5]["role"] == "assistant"
253+
assert messages[6]["role"] == "tool"
254+
assert messages[7]["role"] == "user"
255+
assert messages[7]["content"] == user_message_pm["message"]["content"]
256+
257+
203258
def test_messagebuilder_system_tools():
204259
"""Tests that the system message token count is considered."""
205260
messages = build_messages(

0 commit comments

Comments
 (0)