@@ -20,7 +20,6 @@
 from .._utils import (
     guard_tool_call_id as _guard_tool_call_id,
     now_utc as _now_utc,
-    num_tokens_from_messages,
     number_to_datetime,
 )
 from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool

@@ -59,6 +58,7 @@
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 try:
+    import tiktoken
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
     from openai.types import AllModels, chat, responses
     from openai.types.chat import (
@@ -918,7 +918,10 @@ async def count_tokens( |
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> usage.RequestUsage:
-        """Make a request to the model for counting tokens."""
+        """Count the number of tokens in the given messages."""
+        if self.system != 'openai':
+            raise NotImplementedError('Token counting is only supported for the OpenAI system.')
+
         openai_messages = await self._map_messages(messages, model_request_parameters)
         token_count = num_tokens_from_messages(openai_messages, self.model_name)
 

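Note: with the guard above, token counting becomes a purely local computation via the num_tokens_from_messages helper added at the bottom of this diff. A minimal sketch of the same tiktoken accounting outside this class (the message dicts and model name are illustrative, not part of the diff):

import tiktoken

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
encoding = tiktoken.encoding_for_model('gpt-4o-2024-08-06')
num_tokens = 0
for message in messages:
    num_tokens += 3  # per-message framing overhead
    for value in message.values():
        if isinstance(value, str):
            num_tokens += len(encoding.encode(value))
num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
print(num_tokens)
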
@@ -1726,7 +1729,10 @@ async def count_tokens( |
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> usage.RequestUsage:
-        """Make a request to the model for counting tokens."""
+        """Count the number of tokens in the given messages."""
+        if self.system != 'openai':
+            raise NotImplementedError('Token counting is only supported for the OpenAI system.')
+
         _, openai_messages = await self._map_messages(
             messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
         )
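
The Responses overload applies the same guard and delegates to the same helper. One behavior worth noting in that helper: model names unknown to the installed tiktoken version fall back to the o200k_base encoding. A small sketch of that fallback (the model name is a placeholder, not a real model):

import tiktoken

try:
    encoding = tiktoken.encoding_for_model('some-future-model')  # placeholder name
except KeyError:
    encoding = tiktoken.get_encoding('o200k_base')  # encoding used by the gpt-4o/gpt-5 families
print(len(encoding.encode('Hello, world!')))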
@@ -2368,3 +2374,50 @@ def _map_mcp_call( |
             provider_name=provider_name,
         ),
     )
+
+
+def num_tokens_from_messages(
+    messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam],
+    model: OpenAIModelName,
+) -> int:
+    """Return the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # Model is unknown to this tiktoken version; fall back to the current-generation encoding.
+        encoding = tiktoken.get_encoding('o200k_base')
+    if model in {
+        'gpt-3.5-turbo-0125',
+        'gpt-4-0314',
+        'gpt-4-32k-0314',
+        'gpt-4-0613',
+        'gpt-4-32k-0613',
+        'gpt-4o-mini-2024-07-18',
+        'gpt-4o-2024-08-06',
+    }:
+        tokens_per_message = 3
+        final_primer = 3  # every reply is primed with <|start|>assistant<|message|>
+    elif model in {
+        'gpt-5-2025-08-07',
+    }:
+        tokens_per_message = 3
+        final_primer = 2
+    elif 'gpt-3.5-turbo' in model:
+        # Undated model names fall through to a pinned snapshot with known overheads.
+        return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
+    elif 'gpt-4o-mini' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
+    elif 'gpt-4o' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
+    elif 'gpt-4' in model:
+        return num_tokens_from_messages(messages, model='gpt-4-0613')
+    elif 'gpt-5' in model:
+        return num_tokens_from_messages(messages, model='gpt-5-2025-08-07')
+    else:
+        raise NotImplementedError(f'num_tokens_from_messages() is not implemented for model {model}.')
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for value in message.values():
+            # Only string fields (e.g. role, content, name) contribute; structured parts are skipped.
+            if isinstance(value, str):
+                num_tokens += len(encoding.encode(value))
+    num_tokens += final_primer
+    return num_tokens
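
A short usage sketch for the helper, assuming tiktoken is installed and using illustrative Chat Completions-style message dicts:

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'What is the capital of France?'},
]
print(num_tokens_from_messages(messages, 'gpt-4o'))  # 'gpt-4o' recurses to the pinned 'gpt-4o-2024-08-06'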