
Commit 2293595

kauabh and DouweM authored
Add UsageLimits.count_tokens_before_request using Gemini count_tokens API (#2137)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 5f99595 commit 2293595

8 files changed: +406 −13 lines changed


pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 12 additions & 5 deletions
@@ -351,11 +351,6 @@ async def _prepare_request(
     ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]:
         ctx.state.message_history.append(self.request)
 
-        # Check usage
-        if ctx.deps.usage_limits:  # pragma: no branch
-            ctx.deps.usage_limits.check_before_request(ctx.state.usage)
-
-        # Increment run_step
         ctx.state.run_step += 1
 
         run_context = build_run_context(ctx)
@@ -367,6 +362,18 @@ async def _prepare_request(
 
         message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
 
+        usage = ctx.state.usage
+        if ctx.deps.usage_limits.count_tokens_before_request:
+            # Copy to avoid modifying the original usage object with the counted usage
+            usage = dataclasses.replace(usage)
+
+            counted_usage = await ctx.deps.model.count_tokens(
+                message_history, ctx.deps.model_settings, model_request_parameters
+            )
+            usage.incr(counted_usage)
+
+        ctx.deps.usage_limits.check_before_request(usage)
+
         return model_settings, model_request_parameters, message_history, run_context
 
     def _finish_handling(
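The `dataclasses.replace(usage)` copy above is what keeps the provisional token count out of the run's real usage totals: only the copy receives the counted tokens, and that copy is what `check_before_request` inspects. Below is a minimal standalone sketch of that copy-then-check pattern; the `Usage` class and its `incr` method here are simplified stand-ins for illustration, not the real `pydantic_ai.usage.Usage`.

from dataclasses import dataclass, replace


@dataclass
class Usage:
    # Simplified stand-in for pydantic_ai.usage.Usage, for illustration only.
    requests: int = 0
    request_tokens: int = 0
    total_tokens: int = 0

    def incr(self, other: 'Usage') -> None:
        # Accumulate another usage object into this one in place.
        self.request_tokens += other.request_tokens
        self.total_tokens += other.total_tokens


run_usage = Usage(requests=1, request_tokens=100, total_tokens=150)

# Copy first, so the provisional count never pollutes the run's real totals.
projected = replace(run_usage)
projected.incr(Usage(request_tokens=12, total_tokens=12))

print(run_usage.request_tokens)  # 100 -- unchanged
print(projected.request_tokens)  # 112 -- used only for the pre-request limit check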

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -413,6 +413,16 @@ async def request(
         """Make a request to the model."""
         raise NotImplementedError()
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> Usage:
+        """Make a request to the model for counting tokens."""
+        # This method is not required, but you need to implement it if you want to support `UsageLimits.count_tokens_before_request`.
+        raise NotImplementedError(f'Token counting ahead of the request is not supported by {self.__class__.__name__}')
+
     @asynccontextmanager
     async def request_stream(
         self,
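The hook is opt-in for model implementations: the base class raises `NotImplementedError`, and only models that can count tokens ahead of a request (such as the Google model below) override it. Purely to illustrate the shape such an override takes, here is a self-contained sketch; `MyProviderModel` and `estimate_tokens` are made-up names, and a real subclass of pydantic_ai's `Model` would also have to implement `request()` and its other abstract members.

import asyncio
from dataclasses import dataclass


@dataclass
class Usage:
    # Simplified stand-in for pydantic_ai.usage.Usage.
    request_tokens: int = 0
    total_tokens: int = 0


def estimate_tokens(text: str) -> int:
    # Crude ~4-characters-per-token heuristic, purely for illustration.
    return max(1, len(text) // 4)


class MyProviderModel:
    async def count_tokens(self, messages: list[str]) -> Usage:
        # Mirror of the new hook's contract: messages in, provisional Usage out.
        n = sum(estimate_tokens(m) for m in messages)
        return Usage(request_tokens=n, total_tokens=n)


async def main() -> None:
    usage = await MyProviderModel().count_tokens(
        ['The quick brown fox jumps over the lazy dog.']
    )
    print(usage)  # Usage(request_tokens=11, total_tokens=11)


asyncio.run(main())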

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 66 additions & 6 deletions
@@ -52,13 +52,15 @@
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        CountTokensConfigDict,
         ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GenerationConfigDict,
         GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
@@ -188,6 +190,59 @@ async def request(
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.Usage:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        contents, generation_config = await self._build_content_and_config(
+            messages, model_settings, model_request_parameters
+        )
+
+        # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_schema` includes `typing._UnionGenericAlias`,
+        # so without this we'd need `pyright: ignore[reportUnknownMemberType]` on every line and wouldn't get type checking anyway.
+        generation_config = cast(dict[str, Any], generation_config)
+
+        config = CountTokensConfigDict(
+            http_options=generation_config.get('http_options'),
+        )
+        if self.system != 'google-gla':
+            # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
+            config.update(
+                system_instruction=generation_config.get('system_instruction'),
+                tools=cast(list[ToolDict], generation_config.get('tools')),
+                # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
+                generation_config=GenerationConfigDict(
+                    temperature=generation_config.get('temperature'),
+                    top_p=generation_config.get('top_p'),
+                    max_output_tokens=generation_config.get('max_output_tokens'),
+                    stop_sequences=generation_config.get('stop_sequences'),
+                    presence_penalty=generation_config.get('presence_penalty'),
+                    frequency_penalty=generation_config.get('frequency_penalty'),
+                    thinking_config=generation_config.get('thinking_config'),
+                    media_resolution=generation_config.get('media_resolution'),
+                    response_mime_type=generation_config.get('response_mime_type'),
+                    response_schema=generation_config.get('response_schema'),
+                ),
+            )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+            config=config,
+        )
+        if response.total_tokens is None:
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Total tokens missing from Gemini response', str(response)
+            )
+        return usage.Usage(
+            request_tokens=response.total_tokens,
+            total_tokens=response.total_tokens,
+        )
+
     @asynccontextmanager
     async def request_stream(
         self,
@@ -265,16 +320,23 @@ async def _generate_content(
         model_settings: GoogleModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
-        tools = self._get_tools(model_request_parameters)
+        contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
+        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
+        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
 
+    async def _build_content_and_config(
+        self,
+        messages: list[ModelMessage],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[list[ContentUnionDict], GenerateContentConfigDict]:
+        tools = self._get_tools(model_request_parameters)
         response_mime_type = None
         response_schema = None
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError('Gemini does not support structured output and tools at the same time.')
-
             response_mime_type = 'application/json'
-
             output_object = model_request_parameters.output_object
             assert output_object is not None
             response_schema = self._map_response_schema(output_object)
@@ -311,9 +373,7 @@ async def _generate_content(
             response_mime_type=response_mime_type,
             response_schema=response_schema,
         )
-
-        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
-        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
+        return contents, config
 
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
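`GoogleModel.count_tokens` wraps the Gemini countTokens endpoint via the google-genai SDK, which is also what the recorded cassettes later in this diff capture. As a rough sketch of the underlying (synchronous) SDK call, assuming the `google-genai` package is installed and an API key is available in the environment:

from google import genai

# Assumes an API key in the environment (e.g. GOOGLE_API_KEY); adjust as needed.
client = genai.Client()

response = client.models.count_tokens(
    model='gemini-2.5-flash',
    contents='The quick brown fox jumps over the lazydog.',
)
# The recorded cassette in this commit returns 12 total tokens for this exact prompt.
print(response.total_tokens)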

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 17 additions & 1 deletion
@@ -96,6 +96,10 @@ class UsageLimits:
     """The maximum number of tokens allowed in responses from the model."""
     total_tokens_limit: int | None = None
     """The maximum number of tokens allowed in requests and responses combined."""
+    count_tokens_before_request: bool = False
+    """If True, perform a token counting pass before sending the request to the model,
+    to enforce `request_tokens_limit` ahead of time. This may incur additional overhead
+    (from calling the model's `count_tokens` API before making the actual request) and is disabled by default."""
 
     def has_token_limits(self) -> bool:
         """Returns `True` if this instance places any limits on token counts.
@@ -111,11 +115,23 @@ def has_token_limits(self) -> bool:
         )
 
     def check_before_request(self, usage: Usage) -> None:
-        """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+        """Raises a `UsageLimitExceeded` exception if the next request would exceed any of the limits."""
         request_limit = self.request_limit
         if request_limit is not None and usage.requests >= request_limit:
             raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
 
+        request_tokens = usage.request_tokens or 0
+        if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f'The next request would exceed the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+            )
+
+        total_tokens = usage.total_tokens or 0
+        if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+            raise UsageLimitExceeded(
+                f'The next request would exceed the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})'
+            )
+
     def check_tokens(self, usage: Usage) -> None:
         """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
         request_tokens = usage.request_tokens or 0
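Together with the agent-graph change, the new flag lets a run fail fast when the prompt alone would already exceed a token budget. A sketch of how that is meant to look from user code, assuming a configured Gemini API key and that the `Agent`/`UsageLimits` API shown here matches your installed pydantic-ai version:

from pydantic_ai import Agent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import UsageLimits

agent = Agent('google-gla:gemini-2.5-flash')

try:
    result = agent.run_sync(
        'The quick brown fox jumps over the lazydog.',
        usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True),
    )
    print(result.output)
except UsageLimitExceeded as e:
    # With the flag enabled, the countTokens pass (12 tokens in the recorded
    # cassettes below) trips the limit before generateContent is ever called.
    print(e)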
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+interactions:
+- request:
+    body: '{"contents": [{"parts": [{"text": "The quick brown fox jumps over the lazydog."}],
+      "role": "user"}]}'
+    headers:
+      Content-Type:
+      - application/json
+      user-agent:
+      - google-genai-sdk/1.26.0 gl-python/3.12.7
+      x-goog-api-client:
+      - google-genai-sdk/1.26.0 gl-python/3.12.7
+    method: post
+    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:countTokens
+  response:
+    body:
+      string: "{\n  \"totalTokens\": 12,\n  \"promptTokensDetails\": [\n    {\n      \"modality\":
+        \"TEXT\",\n      \"tokenCount\": 12\n    }\n  ]\n}\n"
+    headers:
+      Alt-Svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      Content-Type:
+      - application/json; charset=UTF-8
+      Date:
+      - Fri, 01 Aug 2025 15:59:25 GMT
+      Server:
+      - scaffolding on HTTPServer2
+      Server-Timing:
+      - gfet4t7; dur=1582
+      Transfer-Encoding:
+      - chunked
+      Vary:
+      - Origin
+      - X-Origin
+      - Referer
+      X-Content-Type-Options:
+      - nosniff
+      X-Frame-Options:
+      - SAMEORIGIN
+      X-XSS-Protection:
+      - '0'
+      content-length:
+      - '117'
+    status:
+      code: 200
+      message: OK
+version: 1
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - '*/*'
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '100'
+      content-type:
+      - application/json
+      host:
+      - generativelanguage.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: The quick brown fox jumps over the lazydog.
+        role: user
+    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:countTokens
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - '117'
+      content-type:
+      - application/json; charset=UTF-8
+      server-timing:
+      - gfet4t7; dur=191
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      promptTokensDetails:
+      - modality: TEXT
+        tokenCount: 12
+      totalTokens: 12
+    status:
+      code: 200
+      message: OK
+- request:
+    headers:
+      accept:
+      - '*/*'
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '124'
+      content-type:
+      - application/json
+      host:
+      - generativelanguage.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: The quick brown fox jumps over the lazydog.
+        role: user
+      generationConfig: {}
+    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - '979'
+      content-type:
+      - application/json; charset=UTF-8
+      server-timing:
+      - gfet4t7; dur=4808
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      candidates:
+      - content:
+          parts:
+          - text: |-
+              That's a classic! It's famously known as a **pangram**, which means it's a sentence that contains every letter of the alphabet.
+
+              It's often used for:
+              * **Typing practice:** To ensure all keys are hit.
+              * **Displaying font samples:** Because it showcases every character.
+
+              Just a small note, it's typically written as "lazy dog" (two words) and usually ends with a period:
+
+              **The quick brown fox jumps over the lazy dog.**
+          role: model
+        finishReason: STOP
+        index: 0
+      modelVersion: gemini-2.5-flash
+      responseId: ZwudaISALoquqtsP9uCG6Qw
+      usageMetadata:
+        candidatesTokenCount: 109
+        promptTokenCount: 12
+        promptTokensDetails:
+        - modality: TEXT
+          tokenCount: 12
+        thoughtsTokenCount: 806
+        totalTokenCount: 927
+    status:
+      code: 200
+      message: OK
+version: 1
