Commit d7f0b87

add token counting

1 parent 359c6d2
5 files changed: +198 −29 lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 50 additions & 0 deletions

@@ -14,6 +14,7 @@
 from types import GenericAlias
 from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
 
+import tiktoken
 from anyio.to_thread import run_sync
 from pydantic import BaseModel, TypeAdapter
 from pydantic.json_schema import JsonSchemaValue
@@ -32,10 +33,14 @@
     AbstractSpan = AbstractSpan
 
 if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletionMessageParam
+    from openai.types.responses.response_input_item_param import ResponseInputItemParam
+
     from pydantic_ai.agent import AgentRun, AgentRunResult
     from pydantic_graph import GraphRun, GraphRunResult
 
     from . import messages as _messages
+    from .models.openai import OpenAIModelName
     from .tools import ObjectJsonSchema
 
 _P = ParamSpec('_P')
@@ -507,3 +512,48 @@ def get_event_loop():
         event_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(event_loop)
     return event_loop
+
+
+def num_tokens_from_messages(
+    messages: list[ChatCompletionMessageParam] | list[ResponseInputItemParam],
+    model: OpenAIModelName = 'gpt-4o-mini-2024-07-18',
+) -> int:
+    """Return the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print('Warning: model not found. Using o200k_base encoding.')  # TODO: How to handle warnings?
+        encoding = tiktoken.get_encoding('o200k_base')
+    if model in {
+        'gpt-3.5-turbo-0125',
+        'gpt-4-0314',
+        'gpt-4-32k-0314',
+        'gpt-4-0613',
+        'gpt-4-32k-0613',
+        'gpt-4o-mini-2024-07-18',
+        'gpt-4o-2024-08-06',
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif 'gpt-3.5-turbo' in model:
+        return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
+    elif 'gpt-4o-mini' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
+    elif 'gpt-4o' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
+    elif 'gpt-4' in model:
+        return num_tokens_from_messages(messages, model='gpt-4-0613')
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}."""
+        )  # TODO: How to handle other models?
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            if isinstance(value, str):
+                num_tokens += len(encoding.encode(value))
+                if key == 'name':
+                    num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
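
The helper follows the OpenAI cookbook recipe: a fixed 3-token overhead per message, the encoded string values of each message (role and content), one extra token when a name field is present, and a final 3 tokens that prime the assistant's reply. A minimal sketch of calling it directly, assuming plain dict-shaped chat-completion messages, might look like this:

# A minimal sketch (not part of the commit): calling the new helper on
# dict-shaped chat messages. 'gpt-4o-2024-08-06' selects the o200k_base encoding.
from pydantic_ai._utils import num_tokens_from_messages

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
print(num_tokens_from_messages(messages, model='gpt-4o-2024-08-06'))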

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 36 additions & 1 deletion

@@ -17,7 +17,12 @@
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
-from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
+from .._utils import (
+    guard_tool_call_id as _guard_tool_call_id,
+    now_utc as _now_utc,
+    num_tokens_from_messages,
+    number_to_datetime,
+)
 from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
@@ -907,6 +912,20 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch
             )
         return ChatCompletionContentPartTextParam(text=text, type='text')
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.RequestUsage:
+        """Count the input tokens locally with tiktoken; no API request is made."""
+        openai_messages = await self._map_messages(messages, model_request_parameters)
+        token_count = num_tokens_from_messages(openai_messages, self.model_name)
+
+        return usage.RequestUsage(
+            input_tokens=token_count,
+        )
+
 
 @deprecated(
     '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
@@ -1701,6 +1720,22 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
                 assert_never(item)
         return responses.EasyInputMessageParam(role='user', content=content)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.RequestUsage:
+        """Count the input tokens locally with tiktoken; no API request is made."""
+        _, openai_messages = await self._map_messages(
+            messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
+        )
+        token_count = num_tokens_from_messages(openai_messages, self.model_name)
+
+        return usage.RequestUsage(
+            input_tokens=token_count,
+        )
+
 
 @dataclass
 class OpenAIStreamedResponse(StreamedResponse):
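
Both implementations reuse the model's own message mapping before counting, so the estimate reflects the payload that would actually be sent, and nothing hits the network. A hypothetical usage sketch (names from the diff, wiring assumed) could be:

# Hypothetical sketch (not part of the commit): counting prompt tokens offline.
# The dummy API key is never used because count_tokens makes no request.
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider


async def main() -> None:
    model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key='dummy'))
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello!')])]
    usage = await model.count_tokens(messages, None, ModelRequestParameters())
    print(usage.input_tokens)


asyncio.run(main())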

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ dependencies = [
 # WARNING if you add optional groups, please update docs/install.md
 logfire = ["logfire[httpx]>=3.14.1"]
 # Models
-openai = ["openai>=1.107.2"]
+openai = ["openai>=1.107.2", "tiktoken>=0.12.0"]
 cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.50.1"]

tests/models/test_openai.py

Lines changed: 55 additions & 0 deletions

@@ -21,6 +21,7 @@
     DocumentUrl,
     ImageUrl,
     ModelHTTPError,
+    ModelMessage,
     ModelProfile,
     ModelRequest,
     ModelResponse,
@@ -3085,3 +3086,57 @@ async def test_cache_point_filtering_responses_model():
     assert len(msg['content']) == 2
     assert msg['content'][0]['text'] == 'text before'  # type: ignore[reportUnknownArgumentType]
     assert msg['content'][1]['text'] == 'text after'  # type: ignore[reportUnknownArgumentType]
+
+
+@pytest.mark.vcr()
+@pytest.mark.parametrize(
+    'model_name,expected_token_count',
+    [
+        ('gpt-3.5-turbo', 115),
+        ('gpt-4-0613', 115),
+        ('gpt-4', 115),
+        ('gpt-4o', 110),
+        ('gpt-4o-mini', 110),
+    ],
+)
+async def test_count_tokens(model_name: str, expected_token_count: int):
+    """Test token counting with the OpenAI Chat and Responses models."""
+    test_messages: list[ModelMessage] = [
+        ModelRequest(
+            parts=[
+                SystemPromptPart(
+                    content='You are a helpful, pattern-following assistant that translates corporate jargon into plain English.',
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+                SystemPromptPart(
+                    content='New synergies will help drive top-line growth.',
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+                SystemPromptPart(
+                    content='Things working well together will increase revenue.',
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+                SystemPromptPart(
+                    content="Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.",
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+                SystemPromptPart(
+                    content="Let's talk later when we're less busy about how to do better.",
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+                UserPromptPart(
+                    content="This late pivot means we don't have time to boil the ocean for the client deliverable.",
+                    timestamp=IsNow(tz=timezone.utc),
+                ),
+            ],
+            run_id=IsStr(),
+        )
+    ]
+
+    chat_model = OpenAIChatModel(model_name, provider=OpenAIProvider(api_key='foobar'))
+    usage_result: RequestUsage = await chat_model.count_tokens(test_messages, {}, ModelRequestParameters())
+    assert usage_result.input_tokens == expected_token_count
+
+    responses_model = OpenAIResponsesModel(model_name, provider=OpenAIProvider(api_key='foobar'))
+    usage_result = await responses_model.count_tokens(test_messages, {}, ModelRequestParameters())
+    assert usage_result.input_tokens == expected_token_count
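
The expected totals follow the cookbook accounting: 3 overhead tokens per message, the encoded role and content strings, and 3 tokens priming the reply. The gpt-3.5/gpt-4 entries use the cl100k_base encoding while the gpt-4o entries use o200k_base, which is where the 115 vs. 110 difference comes from. A rough way to reproduce the arithmetic, assuming the six prompt parts map to six chat messages with no name fields:

# Back-of-the-envelope check (an assumption, not from the commit): mirrors
# num_tokens_from_messages for six plain messages without 'name' fields.
import tiktoken

messages = [
    {'role': 'system', 'content': 'You are a helpful, pattern-following assistant that translates corporate jargon into plain English.'},
    {'role': 'system', 'content': 'New synergies will help drive top-line growth.'},
    {'role': 'system', 'content': 'Things working well together will increase revenue.'},
    {'role': 'system', 'content': "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage."},
    {'role': 'system', 'content': "Let's talk later when we're less busy about how to do better."},
    {'role': 'user', 'content': "This late pivot means we don't have time to boil the ocean for the client deliverable."},
]

for name in ('cl100k_base', 'o200k_base'):  # gpt-4 family vs. gpt-4o family
    enc = tiktoken.get_encoding(name)
    total = 3  # every reply is primed with <|start|>assistant<|message|>
    for m in messages:
        total += 3 + sum(len(enc.encode(v)) for v in m.values())
    print(name, total)  # the test expects 115 (cl100k_base) and 110 (o200k_base)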
