Skip to content

Commit d79be12

Browse files
authored
Merge pull request #15 from pamelafox/toolchoice
Fix tool choice type
2 parents 71f9bfe + 1d5ea86 commit d79be12

File tree

6 files changed

+31
-9
lines changed

6 files changed

+31
-9
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.1.8] - Aug 3, 2024
6+
7+
- Fix the type for the tool_choice param to be inclusive of "auto" and other options.
8+
59
## [0.1.7] - Aug 3, 2024
610

711
- Fix bug where you couldn't pass in example tool calls in `few_shots` to `build_messages`.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Arguments:
3232
* `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo.
3333
* `system_prompt` (`str`): The initial system prompt message.
3434
* `tools` (`List[openai.types.chat.ChatCompletionToolParam]`): (Optional) The tools that will be used in the conversation. These won't be part of the final returned messages, but they will be used to calculate the token count.
35-
* `tool_choice` (`openai.types.chat.ChatCompletionNamedToolChoiceParam`): (Optional) The tool choice that will be used in the conversation. This won't be part of the final returned messages, but it will be used to calculate the token count.
35+
* `tool_choice` (`openai.types.chat.ChatCompletionToolChoiceOptionParam`): (Optional) The tool choice that will be used in the conversation. This won't be part of the final returned messages, but it will be used to calculate the token count.
3636
* `new_user_content` (`str | List[openai.types.chat.ChatCompletionContentPartParam]`): (Optional) The content of new user message to append.
3737
* `past_messages` (`list[openai.types.chat.ChatCompletionMessageParam]`): (Optional) The list of past messages in the conversation.
3838
* `few_shots` (`list[openai.types.chat.ChatCompletionMessageParam]`): (Optional) A few-shot list of messages to insert after the system prompt.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "openai-messages-token-helper"
33
description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API."
4-
version = "0.1.7"
4+
version = "0.1.8"
55
authors = [{name = "Pamela Fox"}]
66
requires-python = ">=3.9"
77
readme = "README.md"

src/openai_messages_token_helper/message_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
ChatCompletionContentPartParam,
99
ChatCompletionMessageParam,
1010
ChatCompletionMessageToolCallParam,
11-
ChatCompletionNamedToolChoiceParam,
1211
ChatCompletionRole,
1312
ChatCompletionSystemMessageParam,
13+
ChatCompletionToolChoiceOptionParam,
1414
ChatCompletionToolMessageParam,
1515
ChatCompletionToolParam,
1616
ChatCompletionUserMessageParam,
@@ -88,7 +88,7 @@ def build_messages(
8888
system_prompt: str,
8989
*,
9090
tools: Optional[list[ChatCompletionToolParam]] = None,
91-
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
91+
tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
9292
new_user_content: Union[str, list[ChatCompletionContentPartParam], None] = None, # list is for GPT4v usage
9393
past_messages: list[ChatCompletionMessageParam] = [], # *not* including system prompt
9494
few_shots: list[ChatCompletionMessageParam] = [], # will always be inserted after system prompt
@@ -103,7 +103,7 @@ def build_messages(
103103
model (str): The model name to use for token calculation, like gpt-3.5-turbo.
104104
system_prompt (str): The initial system prompt message.
105105
tools (list[ChatCompletionToolParam]): A list of tools to include in the conversation.
106-
tool_choice (ChatCompletionNamedToolChoiceParam): The tool to use in the conversation.
106+
tool_choice (ChatCompletionToolChoiceOptionParam): The tool to use in the conversation.
107107
new_user_content (str | List[ChatCompletionContentPartParam]): Content of new user message to append.
108108
past_messages (list[ChatCompletionMessageParam]): The list of past messages in the conversation.
109109
few_shots (list[ChatCompletionMessageParam]): A few-shot list of messages to insert after the system prompt.

src/openai_messages_token_helper/model_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import tiktoken
66
from openai.types.chat import (
77
ChatCompletionMessageParam,
8-
ChatCompletionNamedToolChoiceParam,
98
ChatCompletionSystemMessageParam,
9+
ChatCompletionToolChoiceOptionParam,
1010
ChatCompletionToolParam,
1111
)
1212

@@ -121,7 +121,7 @@ def count_tokens_for_system_and_tools(
121121
model: str,
122122
system_message: ChatCompletionSystemMessageParam | None = None,
123123
tools: list[ChatCompletionToolParam] | None = None,
124-
tool_choice: ChatCompletionNamedToolChoiceParam | None = None,
124+
tool_choice: ChatCompletionToolChoiceOptionParam | None = None,
125125
default_to_cl100k: bool = False,
126126
) -> int:
127127
"""

tests/test_messagebuilder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import typing
22

33
import pytest
4-
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, ChatCompletionToolParam
4+
from openai.types.chat import (
5+
ChatCompletionMessageParam,
6+
ChatCompletionToolChoiceOptionParam,
7+
ChatCompletionToolParam,
8+
)
59
from openai_messages_token_helper import build_messages, count_tokens_for_message
610

711
from .functions import search_sources_toolchoice_auto
@@ -293,7 +297,7 @@ def test_messagebuilder_typing() -> None:
293297
},
294298
}
295299
]
296-
tool_choice: ChatCompletionNamedToolChoiceParam = {
300+
tool_choice: ChatCompletionToolChoiceOptionParam = {
297301
"type": "function",
298302
"function": {"name": "search_sources"},
299303
}
@@ -316,3 +320,17 @@ def test_messagebuilder_typing() -> None:
316320
assert isinstance(messages, list)
317321
if hasattr(typing, "assert_type"):
318322
typing.assert_type(messages[0], ChatCompletionMessageParam)
323+
324+
messages = build_messages(
325+
model="gpt-35-turbo",
326+
system_prompt="Here are some tools you can use to search for sources.",
327+
tools=tools,
328+
tool_choice="auto",
329+
past_messages=past_messages,
330+
new_user_content="What are my health plans?",
331+
max_tokens=90,
332+
)
333+
334+
assert isinstance(messages, list)
335+
if hasattr(typing, "assert_type"):
336+
typing.assert_type(messages[0], ChatCompletionMessageParam)

0 commit comments

Comments
 (0)