Skip to content

Commit c630cbe

Browse files
authored
Merge pull request #10 from pamelafox/motypes
Use more precise types
2 parents db93ee5 + d49c0bf commit c630cbe

File tree

7 files changed

+102
-31
lines changed

7 files changed

+102
-31
lines changed

.github/workflows/python.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ jobs:
3030
run: black . --check --verbose
3131
- name: Run unit tests
3232
run: |
33-
python3 -m pytest -s -vv --cov --cov-fail-under=99
33+
python3 -m pytest -s -vv --cov --cov-fail-under=98
3434
- name: Run type checks
3535
run: mypy .

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.1.3] - May 2, 2024
6+
7+
- Use openai type annotations for more precise type hints, and add a typing test.
8+
59
## [0.1.2] - May 2, 2024
610

711
- Add `py.typed` file so that mypy can find the type hints in this package.

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ Arguments:
3232
* `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo.
3333
* `system_prompt` (`str`): The initial system prompt message.
3434
* `tools` (`List[openai.types.chat.ChatCompletionToolParam]`): (Optional) The tools that will be used in the conversation. These won't be part of the final returned messages, but they will be used to calculate the token count.
35-
* `tool_choice` (`str | dict`): (Optional) The tool choice that will be used in the conversation. This won't be part of the final returned messages, but it will be used to calculate the token count.
35+
* `tool_choice` (`openai.types.chat.ChatCompletionNamedToolChoiceParam`): (Optional) The tool choice that will be used in the conversation. This won't be part of the final returned messages, but it will be used to calculate the token count.
3636
* `new_user_content` (`str | List[openai.types.chat.ChatCompletionContentPartParam]`): (Optional) The content of new user message to append.
37-
* `past_messages` (`list[dict]`): (Optional) The list of past messages in the conversation.
38-
* `few_shots` (`list[dict]`): (Optional) A few-shot list of messages to insert after the system prompt.
37+
* `past_messages` (`list[openai.types.chat.ChatCompletionMessageParam]`): (Optional) The list of past messages in the conversation.
38+
* `few_shots` (`list[openai.types.chat.ChatCompletionMessageParam]`): (Optional) A few-shot list of messages to insert after the system prompt.
3939
* `max_tokens` (`int`): (Optional) The maximum number of tokens allowed for the conversation.
4040
* `fallback_to_default` (`bool`): (Optional) Whether to fallback to default model/token limits if model is not found. Defaults to `False`.
4141

@@ -83,7 +83,7 @@ Counts the number of tokens in a message.
8383
Arguments:
8484

8585
* `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo.
86-
* `message` (`dict`): The message to count tokens for.
86+
* `message` (`openai.types.chat.ChatCompletionMessageParam`): The message to count tokens for.
8787
* `default_to_cl100k` (`bool`): Whether to default to the CL100k token limit if the model is not found.
8888

8989
Returns:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "openai-messages-token-helper"
33
description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API."
4-
version = "0.1.2"
4+
version = "0.1.3"
55
authors = [{name = "Pamela Fox"}]
66
requires-python = ">=3.9"
77
readme = "README.md"

src/openai_messages_token_helper/message_builder.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
import logging
22
import unicodedata
3+
from collections.abc import Iterable
34
from typing import Optional, Union
45

56
from openai.types.chat import (
67
ChatCompletionAssistantMessageParam,
78
ChatCompletionContentPartParam,
89
ChatCompletionMessageParam,
10+
ChatCompletionNamedToolChoiceParam,
11+
ChatCompletionRole,
912
ChatCompletionSystemMessageParam,
13+
ChatCompletionToolParam,
1014
ChatCompletionUserMessageParam,
1115
)
1216

1317
from .model_helper import count_tokens_for_message, count_tokens_for_system_and_tools, get_token_limit
1418

1519

16-
def normalize_content(content: Union[str, list[ChatCompletionContentPartParam]]):
20+
def normalize_content(content: Union[str, Iterable[ChatCompletionContentPartParam]]):
1721
if isinstance(content, str):
1822
return unicodedata.normalize("NFC", content)
19-
elif isinstance(content, list):
23+
else:
2024
for part in content:
2125
if "image_url" not in part:
2226
part["text"] = unicodedata.normalize("NFC", part["text"])
@@ -36,14 +40,19 @@ class _MessageBuilder:
3640
"""
3741

3842
def __init__(self, system_content: str):
39-
self.messages: list[ChatCompletionMessageParam] = [
40-
ChatCompletionSystemMessageParam(role="system", content=normalize_content(system_content))
41-
]
43+
self.system_message = ChatCompletionSystemMessageParam(role="system", content=normalize_content(system_content))
44+
self.messages: list[ChatCompletionMessageParam] = []
4245

43-
def insert_message(self, role: str, content: Union[str, list[ChatCompletionContentPartParam]], index: int = 1):
46+
@property
47+
def all_messages(self) -> list[ChatCompletionMessageParam]:
48+
return [self.system_message] + self.messages
49+
50+
def insert_message(
51+
self, role: ChatCompletionRole, content: Union[str, Iterable[ChatCompletionContentPartParam]], index: int = 0
52+
):
4453
"""
4554
Inserts a message into the conversation at the specified index,
46-
or at index 1 (after system message) if no index is specified.
55+
or at index 0 if no index is specified.
4756
Args:
4857
role (ChatCompletionRole): The role of the message sender (either "user", "system", or "assistant").
4958
content (str | List[ChatCompletionContentPartParam]): The content of the message.
@@ -63,11 +72,11 @@ def build_messages(
6372
model: str,
6473
system_prompt: str,
6574
*,
66-
tools: Optional[list[dict[str, dict]]] = None,
67-
tool_choice: Optional[Union[str, dict]] = None,
75+
tools: Optional[list[ChatCompletionToolParam]] = None,
76+
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
6877
new_user_content: Union[str, list[ChatCompletionContentPartParam], None] = None, # list is for GPT4v usage
69-
past_messages: list[dict[str, str]] = [], # *not* including system prompt
70-
few_shots=[], # will always be inserted after system prompt
78+
past_messages: list[ChatCompletionMessageParam] = [], # *not* including system prompt
79+
few_shots: list[ChatCompletionMessageParam] = [], # will always be inserted after system prompt
7180
max_tokens: Optional[int] = None,
7281
fallback_to_default: bool = False,
7382
) -> list[ChatCompletionMessageParam]:
@@ -78,11 +87,11 @@ def build_messages(
7887
Args:
7988
model (str): The model name to use for token calculation, like gpt-3.5-turbo.
8089
system_prompt (str): The initial system prompt message.
81-
tools (list[dict]): A list of tools to include in the conversation.
82-
tool_choice (str | dict): The tool to use in the conversation.
90+
tools (list[ChatCompletionToolParam]): A list of tools to include in the conversation.
91+
tool_choice (ChatCompletionNamedToolChoiceParam): The tool to use in the conversation.
8392
new_user_content (str | List[ChatCompletionContentPartParam]): Content of new user message to append.
84-
past_messages (list[dict]): The list of past messages in the conversation.
85-
few_shots (list[dict]): A few-shot list of messages to insert after the system prompt.
93+
past_messages (list[ChatCompletionMessageParam]): The list of past messages in the conversation.
94+
few_shots (list[ChatCompletionMessageParam]): A few-shot list of messages to insert after the system prompt.
8695
max_tokens (int): The maximum number of tokens allowed for the conversation.
8796
fallback_to_default (bool): Whether to fallback to default model if the model is not found.
8897
"""
@@ -93,17 +102,19 @@ def build_messages(
93102
message_builder = _MessageBuilder(system_prompt)
94103

95104
for shot in reversed(few_shots):
96-
message_builder.insert_message(shot.get("role"), shot.get("content"))
105+
if shot["role"] is None or shot["content"] is None:
106+
raise ValueError("Few-shot messages must have both role and content")
107+
message_builder.insert_message(shot["role"], shot["content"])
97108

98-
append_index = len(few_shots) + 1
109+
append_index = len(few_shots)
99110

100111
if new_user_content:
101112
message_builder.insert_message("user", new_user_content, index=append_index)
102113

103114
total_token_count = count_tokens_for_system_and_tools(
104-
model, message_builder.messages[0], tools, tool_choice, default_to_cl100k=fallback_to_default
115+
model, message_builder.system_message, tools, tool_choice, default_to_cl100k=fallback_to_default
105116
)
106-
for existing_message in message_builder.messages[1:]:
117+
for existing_message in message_builder.messages:
107118
total_token_count += count_tokens_for_message(model, existing_message, default_to_cl100k=fallback_to_default)
108119

109120
newest_to_oldest = list(reversed(past_messages))
@@ -112,6 +123,9 @@ def build_messages(
112123
if (total_token_count + potential_message_count) > max_tokens:
113124
logging.info("Reached max tokens of %d, history will be truncated", max_tokens)
114125
break
126+
127+
if message["role"] is None or message["content"] is None:
128+
raise ValueError("Chat history messages must have both role and content")
115129
message_builder.insert_message(message["role"], message["content"], index=append_index)
116130
total_token_count += potential_message_count
117-
return message_builder.messages
131+
return message_builder.all_messages

src/openai_messages_token_helper/model_helper.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from __future__ import annotations
22

33
import logging
4-
from collections.abc import Mapping
54

65
import tiktoken
6+
from openai.types.chat import (
7+
ChatCompletionMessageParam,
8+
ChatCompletionNamedToolChoiceParam,
9+
ChatCompletionSystemMessageParam,
10+
ChatCompletionToolParam,
11+
)
712

813
from .function_format import format_function_definitions
914
from .images_helper import count_tokens_for_image
@@ -69,7 +74,7 @@ def encoding_for_model(model: str, default_to_cl100k=False) -> tiktoken.Encoding
6974
raise
7075

7176

72-
def count_tokens_for_message(model: str, message: Mapping[str, object], default_to_cl100k=False) -> int:
77+
def count_tokens_for_message(model: str, message: ChatCompletionMessageParam, default_to_cl100k=False) -> int:
7378
"""
7479
Calculate the number of tokens required to encode a message. Based off cookbook:
7580
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -113,9 +118,9 @@ def count_tokens_for_message(model: str, message: Mapping[str, object], default_
113118

114119
def count_tokens_for_system_and_tools(
115120
model: str,
116-
system_message: Mapping[str, object] | None = None,
117-
tools: list[dict[str, dict]] | None = None,
118-
tool_choice: str | dict | None = None,
121+
system_message: ChatCompletionSystemMessageParam | None = None,
122+
tools: list[ChatCompletionToolParam] | None = None,
123+
tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
119124
default_to_cl100k: bool = False,
120125
) -> int:
121126
"""

tests/test_messagebuilder.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import typing
2+
13
import pytest
4+
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, ChatCompletionToolParam
25
from openai_messages_token_helper import build_messages, count_tokens_for_message
36

47
from .functions import search_sources_toolchoice_auto
@@ -213,3 +216,48 @@ def test_messagebuilder_system_tools():
213216
max_tokens=90,
214217
)
215218
assert messages == [search_sources_toolchoice_auto["system_message"], user_message_pm["message"]]
219+
220+
221+
def test_messagebuilder_typing() -> None:
222+
tools: list[ChatCompletionToolParam] = [
223+
{
224+
"type": "function",
225+
"function": {
226+
"name": "search_sources",
227+
"description": "Retrieve sources from the Azure AI Search index",
228+
"parameters": {
229+
"type": "object",
230+
"properties": {
231+
"search_query": {
232+
"type": "string",
233+
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
234+
}
235+
},
236+
"required": ["search_query"],
237+
},
238+
},
239+
}
240+
]
241+
tool_choice: ChatCompletionNamedToolChoiceParam = {
242+
"type": "function",
243+
"function": {"name": "search_sources"},
244+
}
245+
246+
past_messages: list[ChatCompletionMessageParam] = [
247+
{"role": "user", "content": "What are my health plans?"},
248+
{"role": "assistant", "content": "Here are some tools you can use to search for sources."},
249+
]
250+
251+
messages = build_messages(
252+
model="gpt-35-turbo",
253+
system_prompt="Here are some tools you can use to search for sources.",
254+
tools=tools,
255+
tool_choice=tool_choice,
256+
past_messages=past_messages,
257+
new_user_content="What are my health plans?",
258+
max_tokens=90,
259+
)
260+
261+
assert isinstance(messages, list)
262+
if hasattr(typing, "assert_type"):
263+
typing.assert_type(messages[0], ChatCompletionMessageParam)

0 commit comments

Comments
 (0)