Skip to content

Commit 46cd331

Browse files
committed
address some of the comments
1 parent 6396f5d commit 46cd331

File tree

4 files changed

+200085
-50
lines changed

4 files changed

+200085
-50
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from types import GenericAlias
1515
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
1616

17-
import tiktoken
1817
from anyio.to_thread import run_sync
1918
from pydantic import BaseModel, TypeAdapter
2019
from pydantic.json_schema import JsonSchemaValue
@@ -33,14 +32,10 @@
3332
AbstractSpan = AbstractSpan
3433

3534
if TYPE_CHECKING:
36-
from openai.types.chat import ChatCompletionMessageParam
37-
from openai.types.responses.response_input_item_param import ResponseInputItemParam
38-
3935
from pydantic_ai.agent import AgentRun, AgentRunResult
4036
from pydantic_graph import GraphRun, GraphRunResult
4137

4238
from . import messages as _messages
43-
from .models.openai import OpenAIModelName
4439
from .tools import ObjectJsonSchema
4540

4641
_P = ParamSpec('_P')
@@ -512,45 +507,3 @@ def get_event_loop():
512507
event_loop = asyncio.new_event_loop()
513508
asyncio.set_event_loop(event_loop)
514509
return event_loop
515-
516-
517-
def num_tokens_from_messages(
518-
messages: list[ChatCompletionMessageParam] | list[ResponseInputItemParam],
519-
model: OpenAIModelName = 'gpt-4o-mini-2024-07-18',
520-
) -> int:
521-
"""Return the number of tokens used by a list of messages."""
522-
try:
523-
encoding = tiktoken.encoding_for_model(model)
524-
except KeyError:
525-
print('Warning: model not found. Using o200k_base encoding.') # TODO: How to handle warnings?
526-
encoding = tiktoken.get_encoding('o200k_base')
527-
if model in {
528-
'gpt-3.5-turbo-0125',
529-
'gpt-4-0314',
530-
'gpt-4-32k-0314',
531-
'gpt-4-0613',
532-
'gpt-4-32k-0613',
533-
'gpt-4o-mini-2024-07-18',
534-
'gpt-4o-2024-08-06',
535-
}:
536-
tokens_per_message = 3
537-
elif 'gpt-3.5-turbo' in model:
538-
return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
539-
elif 'gpt-4o-mini' in model:
540-
return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
541-
elif 'gpt-4o' in model:
542-
return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
543-
elif 'gpt-4' in model:
544-
return num_tokens_from_messages(messages, model='gpt-4-0613')
545-
else:
546-
raise NotImplementedError(
547-
f"""num_tokens_from_messages() is not implemented for model {model}."""
548-
) # TODO: How to handle other models?
549-
num_tokens = 0
550-
for message in messages:
551-
num_tokens += tokens_per_message
552-
for value in message.values():
553-
if isinstance(value, str):
554-
num_tokens += len(encoding.encode(value))
555-
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
556-
return num_tokens

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from .._utils import (
2121
guard_tool_call_id as _guard_tool_call_id,
2222
now_utc as _now_utc,
23-
num_tokens_from_messages,
2423
number_to_datetime,
2524
)
2625
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
@@ -59,6 +58,7 @@
5958
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
6059

6160
try:
61+
import tiktoken
6262
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
6363
from openai.types import AllModels, chat, responses
6464
from openai.types.chat import (
@@ -918,7 +918,10 @@ async def count_tokens(
918918
model_settings: ModelSettings | None,
919919
model_request_parameters: ModelRequestParameters,
920920
) -> usage.RequestUsage:
921-
"""Make a request to the model for counting tokens."""
921+
"""Count the number of tokens in the given messages."""
922+
if self.system != 'openai':
923+
raise NotImplementedError('Token counting is only supported for OpenAI system.')
924+
922925
openai_messages = await self._map_messages(messages, model_request_parameters)
923926
token_count = num_tokens_from_messages(openai_messages, self.model_name)
924927

@@ -1726,7 +1729,10 @@ async def count_tokens(
17261729
model_settings: ModelSettings | None,
17271730
model_request_parameters: ModelRequestParameters,
17281731
) -> usage.RequestUsage:
1729-
"""Make a request to the model for counting tokens."""
1732+
"""Count the number of tokens in the given messages."""
1733+
if self.system != 'openai':
1734+
raise NotImplementedError('Token counting is only supported for OpenAI system.')
1735+
17301736
_, openai_messages = await self._map_messages(
17311737
messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
17321738
)
@@ -2368,3 +2374,50 @@ def _map_mcp_call(
23682374
provider_name=provider_name,
23692375
),
23702376
)
2377+
2378+
2379+
def num_tokens_from_messages(
2380+
messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam],
2381+
model: OpenAIModelName,
2382+
) -> int:
2383+
"""Return the number of tokens used by a list of messages."""
2384+
try:
2385+
encoding = tiktoken.encoding_for_model(model)
2386+
except KeyError:
2387+
encoding = tiktoken.get_encoding('o200k_base')
2388+
if model in {
2389+
'gpt-3.5-turbo-0125',
2390+
'gpt-4-0314',
2391+
'gpt-4-32k-0314',
2392+
'gpt-4-0613',
2393+
'gpt-4-32k-0613',
2394+
'gpt-4o-mini-2024-07-18',
2395+
'gpt-4o-2024-08-06',
2396+
}:
2397+
tokens_per_message = 3
2398+
final_primer = 3 # every reply is primed with <|start|>assistant<|message|>
2399+
elif model in {
2400+
'gpt-5-2025-08-07',
2401+
}:
2402+
tokens_per_message = 3
2403+
final_primer = 2
2404+
elif 'gpt-3.5-turbo' in model:
2405+
return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
2406+
elif 'gpt-4o-mini' in model:
2407+
return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
2408+
elif 'gpt-4o' in model:
2409+
return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
2410+
elif 'gpt-4' in model:
2411+
return num_tokens_from_messages(messages, model='gpt-4-0613')
2412+
elif 'gpt-5' in model:
2413+
return num_tokens_from_messages(messages, model='gpt-5-2025-08-07')
2414+
else:
2415+
raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""")
2416+
num_tokens = 0
2417+
for message in messages:
2418+
num_tokens += tokens_per_message
2419+
for value in message.values():
2420+
if isinstance(value, str):
2421+
num_tokens += len(encoding.encode(value))
2422+
num_tokens += final_primer
2423+
return num_tokens

0 commit comments

Comments
 (0)