Skip to content

Commit a840ee4

Browse files
committed
Add arg to message_builder
1 parent 6626140 commit a840ee4

File tree

5 files changed

+35
-9
lines changed

5 files changed

+35
-9
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.0.6] - April 24, 2024
6+
7+
- Add keyword argument `fallback_to_default` to `build_messages` function to allow for defaulting to the CL100k token encoder and minimum GPT token limit if the model is not found.
8+
- Fix usage of the `past_messages` argument of `build_messages` so that the last past message is no longer skipped. (The new user message should *not* be included in `past_messages`.)
9+
510
## [0.0.5] - April 24, 2024
611

712
- Add keyword argument `default_to_cl100k` to `count_tokens_for_message` function to allow for defaulting to the CL100k token limit if the model is not found.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Arguments:
3535
* `past_messages` (`list[dict]`): The list of past messages in the conversation.
3636
* `few_shots` (`list[dict]`): A few-shot list of messages to insert after the system prompt.
3737
* `max_tokens` (`int`): The maximum number of tokens allowed for the conversation.
38+
* `fallback_to_default` (`bool`): Whether to fall back to the default model/token limits if the model is not found. Defaults to `False`.
3839

3940
Returns:
4041

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "openai-messages-token-helper"
33
description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API."
4-
version = "0.0.5"
4+
version = "0.0.6"
55
authors = [{name = "Pamela Fox"}]
66
requires-python = ">=3.9"
77
readme = "README.md"

src/openai_messages_token_helper/message_builder.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import unicodedata
3-
from collections.abc import Mapping
43
from typing import Optional, Union
54

65
from openai.types.chat import (
@@ -52,9 +51,6 @@ def insert_message(self, role: str, content: Union[str, list[ChatCompletionConte
5251
raise ValueError(f"Invalid role: {role}")
5352
self.messages.insert(index, message)
5453

55-
def count_tokens_for_message(self, message: Mapping[str, object]):
56-
return count_tokens_for_message(self.model, message)
57-
5854
def normalize_content(self, content: Union[str, list[ChatCompletionContentPartParam]]):
5955
if isinstance(content, str):
6056
return unicodedata.normalize("NFC", content)
@@ -72,6 +68,7 @@ def build_messages(
7268
past_messages: list[dict[str, str]] = [], # *not* including system prompt
7369
few_shots=[], # will always be inserted after system prompt
7470
max_tokens: Optional[int] = None,
71+
fallback_to_default: bool = False,
7572
) -> list[ChatCompletionMessageParam]:
7673
"""
7774
Build a list of messages for a chat conversation, given the system prompt, new user message,
@@ -84,10 +81,11 @@ def build_messages(
8481
past_messages (list[dict]): The list of past messages in the conversation.
8582
few_shots (list[dict]): A few-shot list of messages to insert after the system prompt.
8683
max_tokens (int): The maximum number of tokens allowed for the conversation.
84+
fallback_to_default (bool): Whether to fall back to the default model if the model is not found.
8785
"""
8886
message_builder = MessageBuilder(system_prompt, model)
8987
if max_tokens is None:
90-
max_tokens = get_token_limit(model)
88+
max_tokens = get_token_limit(model, default_to_minimum=fallback_to_default)
9189

9290
for shot in reversed(few_shots):
9391
message_builder.insert_message(shot.get("role"), shot.get("content"))
@@ -99,11 +97,11 @@ def build_messages(
9997

10098
total_token_count = 0
10199
for existing_message in message_builder.messages:
102-
total_token_count += message_builder.count_tokens_for_message(existing_message)
100+
total_token_count += count_tokens_for_message(model, existing_message, default_to_cl100k=fallback_to_default)
103101

104-
newest_to_oldest = list(reversed(past_messages[:-1]))
102+
newest_to_oldest = list(reversed(past_messages))
105103
for message in newest_to_oldest:
106-
potential_message_count = message_builder.count_tokens_for_message(message)
104+
potential_message_count = count_tokens_for_message(model, message, default_to_cl100k=fallback_to_default)
107105
if (total_token_count + potential_message_count) > max_tokens:
108106
logging.info("Reached max tokens of %d, history will be truncated", max_tokens)
109107
break

tests/test_messagebuilder.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from openai_messages_token_helper import build_messages, count_tokens_for_message
23

34
from .messages import system_message_short, system_message_unicode, user_message, user_message_unicode
@@ -33,3 +34,24 @@ def test_messagebuilder_unicode_append():
3334
assert messages == [system_message_unicode["message"], user_message_unicode["message"]]
3435
assert count_tokens_for_message("gpt-35-turbo", messages[0]) == system_message_unicode["count"]
3536
assert count_tokens_for_message("gpt-35-turbo", messages[1]) == user_message_unicode["count"]
37+
38+
39+
def test_messagebuilder_model_error():
40+
model = "phi-3"
41+
with pytest.raises(ValueError, match="Called with unknown model name: phi-3"):
42+
build_messages(
43+
model, system_message_short["message"]["content"], new_user_message=user_message["message"]["content"]
44+
)
45+
46+
47+
def test_messagebuilder_model_fallback():
48+
model = "phi-3"
49+
messages = build_messages(
50+
model,
51+
system_message_short["message"]["content"],
52+
new_user_message=user_message["message"]["content"],
53+
fallback_to_default=True,
54+
)
55+
assert messages == [system_message_short["message"], user_message["message"]]
56+
assert count_tokens_for_message(model, messages[0], default_to_cl100k=True) == system_message_short["count"]
57+
assert count_tokens_for_message(model, messages[1], default_to_cl100k=True) == user_message["count"]

0 commit comments

Comments
 (0)