
Commit 648a39c

Defaults for encoders and limits
1 parent 06ce38e commit 648a39c

File tree

4 files changed: +53 additions, -21 deletions

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,11 @@

All notable changes to this project will be documented in this file.

+## [0.0.5] - April 24, 2024
+
+- Add keyword argument `default_to_cl100k` to the `count_tokens_for_message` function to allow defaulting to the CL100k encoding if the model is not found.
+- Add keyword argument `default_to_minimum` to the `get_token_limit` function to allow defaulting to the minimum token limit if the model is not found.
+
## [0.0.4] - April 21, 2024

- Rename to openai-messages-token-helper from llm-messages-token-helper to reflect library's current OpenAI focus.
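
As a quick illustration of the first new flag, here is a minimal sketch (not from this commit) of calling `count_tokens_for_message` with `default_to_cl100k=True`; `my-custom-model` is a hypothetical name absent from the library's model tables, and the top-level import assumes the package re-exports the helper:

```python
from openai_messages_token_helper import count_tokens_for_message  # assumed re-export

message = {"role": "user", "content": "hello"}

# "my-custom-model" is hypothetical: tiktoken does not recognize it, so the
# helper logs a warning and falls back to the cl100k_base encoding instead
# of raising ValueError.
num_tokens = count_tokens_for_message("my-custom-model", message, default_to_cl100k=True)
```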

README.md

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,8 @@ Counts the number of tokens in a message.
Arguments:

* `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo.
+* `message` (`dict`): The message to count tokens for.
+* `default_to_cl100k` (`bool`): Whether to default to the CL100k encoding if the model is not found.

Returns:

@@ -129,6 +131,7 @@ Get the token limit for a given GPT model name (OpenAI.com or Azure OpenAI supported).
Arguments:

* `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo (OpenAI.com) or gpt-35-turbo (Azure).
+* `default_to_minimum` (`bool`): Whether to default to the minimum token limit if the model is not found.

Returns:
src/openai_messages_token_helper/model_helper.py

Lines changed: 28 additions & 13 deletions
@@ -1,5 +1,6 @@
from __future__ import annotations

+import logging
from collections.abc import Mapping

import tiktoken
@@ -19,28 +20,37 @@

AOAI_2_OAI = {"gpt-35-turbo": "gpt-3.5-turbo", "gpt-35-turbo-16k": "gpt-3.5-turbo-16k", "gpt-4v": "gpt-4-turbo-vision"}

+logger = logging.getLogger("openai_messages_token_helper")

-def get_token_limit(model: str) -> int:
+
+def get_token_limit(model: str, default_to_minimum=False) -> int:
    """
    Get the token limit for a given GPT model name (OpenAI.com or Azure OpenAI supported).
    Args:
        model (str): The name of the model to get the token limit for.
+        default_to_minimum (bool): Whether to default to the minimum token limit if the model is not found.
    Returns:
        int: The token limit for the model.
    """
    if model not in MODELS_2_TOKEN_LIMITS:
-        raise ValueError(f"Called with unknown model name: {model}")
+        if default_to_minimum:
+            min_token_limit = min(MODELS_2_TOKEN_LIMITS.values())
+            logger.warning("Model %s not found, defaulting to minimum token limit %d", model, min_token_limit)
+            return min_token_limit
+        else:
+            raise ValueError(f"Called with unknown model name: {model}")
    return MODELS_2_TOKEN_LIMITS[model]


-def count_tokens_for_message(model: str, message: Mapping[str, object]) -> int:
+def count_tokens_for_message(model: str, message: Mapping[str, object], default_to_cl100k=False) -> int:
    """
    Calculate the number of tokens required to encode a message. Based off cookbook:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        model (str): The name of the model to use for encoding.
        message (Mapping): The message to encode, in a dictionary-like object.
+        default_to_cl100k (bool): Whether to default to the CL100k encoding if the model is not found.
    Returns:
        int: The total number of tokens required to encode the message.

@@ -49,8 +59,22 @@ def count_tokens_for_message(model: str, message: Mapping[str, object]) -> int:
    >> count_tokens_for_message(model, message)
    13
    """
+    if (
+        model == ""
+        or model is None
+        or (model not in AOAI_2_OAI and model not in MODELS_2_TOKEN_LIMITS and not default_to_cl100k)
+    ):
+        raise ValueError("Expected valid OpenAI GPT model name")
+    model = AOAI_2_OAI.get(model, model)
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        if default_to_cl100k:
+            logger.warning("Model %s not found, defaulting to CL100k encoding", model)
+            encoding = tiktoken.get_encoding("cl100k_base")
+        else:
+            raise

-    encoding = tiktoken.encoding_for_model(get_oai_chatmodel_tiktok(model))
    # Assumes we're using a recent model
    tokens_per_message = 3

@@ -72,12 +96,3 @@ def count_tokens_for_message(model: str, message: Mapping[str, object]) -> int:
            num_tokens += 1
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
-
-
-def get_oai_chatmodel_tiktok(aoaimodel: str) -> str:
-    message = "Expected valid OpenAI GPT model name"
-    if aoaimodel == "" or aoaimodel is None:
-        raise ValueError(message)
-    if aoaimodel not in AOAI_2_OAI and aoaimodel not in MODELS_2_TOKEN_LIMITS:
-        raise ValueError(message)
-    return AOAI_2_OAI.get(aoaimodel, aoaimodel)
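
One note on the new `logger`: it is created with `logging.getLogger("openai_messages_token_helper")` and no handler, so the fallback warnings stay invisible unless the application configures logging. A minimal standard-library sketch of how a caller might opt in:

```python
import logging

# Simplest option: a root handler at WARNING level surfaces the library's messages.
logging.basicConfig(level=logging.WARNING)

# Or adjust only this library's named logger, leaving the rest of the app alone.
logging.getLogger("openai_messages_token_helper").setLevel(logging.ERROR)
```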

tests/test_modelhelper.py

Lines changed: 17 additions & 8 deletions
@@ -18,6 +18,12 @@ def test_get_token_limit_error():
        get_token_limit("gpt-3")


+def test_get_token_limit_default(caplog):
+    with caplog.at_level("WARNING"):
+        assert get_token_limit("gpt-3", default_to_minimum=True) == 4000
+        assert "Model gpt-3 not found, defaulting to minimum token limit 4000" in caplog.text
+
+
# parameterize the model and the expected number of tokens
@pytest.mark.parametrize(
    "model",
@@ -58,14 +64,17 @@ def test_count_tokens_for_message_error():
        count_tokens_for_message(model, message)


-def test_get_oai_chatmodel_tiktok_error():
-    message = {
-        "role": "user",
-        "content": "hello",
-    }
+def test_count_tokens_for_message_model_error():
    with pytest.raises(ValueError, match="Expected valid OpenAI GPT model name"):
-        count_tokens_for_message("", message)
+        count_tokens_for_message("", user_message["message"])
    with pytest.raises(ValueError, match="Expected valid OpenAI GPT model name"):
-        count_tokens_for_message(None, message)
+        count_tokens_for_message(None, user_message["message"])
    with pytest.raises(ValueError, match="Expected valid OpenAI GPT model name"):
-        count_tokens_for_message("gpt44", message)
+        count_tokens_for_message("gpt44", user_message["message"])
+
+
+def test_count_tokens_for_message_model_default(caplog):
+    model = "phi-3"
+    with caplog.at_level("WARNING"):
+        assert count_tokens_for_message(model, user_message["message"], default_to_cl100k=True) == user_message["count"]
+        assert "Model phi-3 not found, defaulting to CL100k encoding" in caplog.text
