Skip to content

Commit c23a286

Browse files
Add support for usage limits (#409)
Co-authored-by: sydney-runkle <[email protected]> Co-authored-by: Sydney Runkle <[email protected]>
1 parent 47d3e5c commit c23a286

23 files changed

+439
-63
lines changed

docs/agents.md

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,90 @@ You can also pass messages from previous runs to continue a conversation or prov
103103

104104
### Additional Configuration
105105

106+
#### Usage Limits
107+
108+
PydanticAI offers a [`settings.UsageLimits`][pydantic_ai.settings.UsageLimits] structure to help you limit your
109+
usage (tokens and/or requests) on model runs.
110+
111+
You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions.
112+
113+
Consider the following example, where we limit the number of response tokens:
114+
115+
```py
116+
from pydantic_ai import Agent
117+
from pydantic_ai.exceptions import UsageLimitExceeded
118+
from pydantic_ai.settings import UsageLimits
119+
120+
agent = Agent('claude-3-5-sonnet-latest')
121+
122+
result_sync = agent.run_sync(
123+
'What is the capital of Italy? Answer with just the city.',
124+
usage_limits=UsageLimits(response_tokens_limit=10),
125+
)
126+
print(result_sync.data)
127+
#> Rome
128+
print(result_sync.usage())
129+
"""
130+
Usage(requests=1, request_tokens=62, response_tokens=1, total_tokens=63, details=None)
131+
"""
132+
133+
try:
134+
result_sync = agent.run_sync(
135+
'What is the capital of Italy? Answer with a paragraph.',
136+
usage_limits=UsageLimits(response_tokens_limit=10),
137+
)
138+
except UsageLimitExceeded as e:
139+
print(e)
140+
#> Exceeded the response_tokens_limit of 10 (response_tokens=32)
141+
```
142+
143+
Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling:
144+
145+
```py
146+
from typing_extensions import TypedDict
147+
148+
from pydantic_ai import Agent, ModelRetry
149+
from pydantic_ai.exceptions import UsageLimitExceeded
150+
from pydantic_ai.settings import UsageLimits
151+
152+
153+
class NeverResultType(TypedDict):
154+
"""
155+
Never ever coerce data to this type.
156+
"""
157+
158+
never_use_this: str
159+
160+
161+
agent = Agent(
162+
'claude-3-5-sonnet-latest',
163+
result_type=NeverResultType,
164+
system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.',
165+
)
166+
167+
168+
@agent.tool_plain(retries=5) # (1)!
169+
def infinite_retry_tool() -> int:
170+
raise ModelRetry('Please try again.')
171+
172+
173+
try:
174+
result_sync = agent.run_sync(
175+
'Begin infinite retry loop!', usage_limits=UsageLimits(request_limit=3) # (2)!
176+
)
177+
except UsageLimitExceeded as e:
178+
print(e)
179+
#> The next request would exceed the request_limit of 3
180+
```
181+
182+
1. This tool has the ability to retry 5 times before erroring, simulating a tool that might get stuck in a loop.
183+
2. This run will error after 3 requests, preventing the infinite tool calling.
184+
185+
!!! note
186+
    This is especially relevant if you've registered a lot of tools; `request_limit` can be used to prevent the model from choosing to make too many of these calls.
187+
188+
#### Model (Run) Settings
189+
106190
PydanticAI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine tune your requests.
107191
This structure allows you to configure common parameters that influence the model's behavior, such as `temperature`, `max_tokens`,
108192
`timeout`, and more.

docs/api/models/ollama.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ result = agent.run_sync('Where were the olympics held in 2012?')
3232
print(result.data)
3333
#> city='London' country='United Kingdom'
3434
print(result.usage())
35-
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
35+
"""
36+
Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None)
37+
"""
3638
```
3739

3840
## Example using a remote server
@@ -60,7 +62,9 @@ result = agent.run_sync('Where were the olympics held in 2012?')
6062
print(result.data)
6163
#> city='London' country='United Kingdom'
6264
print(result.usage())
63-
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
65+
"""
66+
Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None)
67+
"""
6468
```
6569

6670
1. The name of the model running on the remote server

docs/api/settings.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
inherited_members: true
66
members:
77
- ModelSettings
8+
- UsageLimits

docs/results.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ result = agent.run_sync('Where were the olympics held in 2012?')
1919
print(result.data)
2020
#> city='London' country='United Kingdom'
2121
print(result.usage())
22-
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
22+
"""
23+
Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None)
24+
"""
2325
```
2426

2527
_(This example is complete, it can be run "as is")_
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
from importlib.metadata import version
22

33
from .agent import Agent
4-
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
4+
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
55
from .tools import RunContext, Tool
66

7-
__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
7+
__all__ = (
8+
'Agent',
9+
'RunContext',
10+
'Tool',
11+
'AgentRunError',
12+
'ModelRetry',
13+
'UnexpectedModelBehavior',
14+
'UsageLimitExceeded',
15+
'UserError',
16+
'__version__',
17+
)
818
__version__ = version('pydantic_ai_slim')

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
result,
2323
)
2424
from .result import ResultData
25-
from .settings import ModelSettings, merge_model_settings
25+
from .settings import ModelSettings, UsageLimits, merge_model_settings
2626
from .tools import (
2727
AgentDeps,
2828
RunContext,
@@ -191,6 +191,7 @@ async def run(
191191
model: models.Model | models.KnownModelName | None = None,
192192
deps: AgentDeps = None,
193193
model_settings: ModelSettings | None = None,
194+
usage_limits: UsageLimits | None = None,
194195
infer_name: bool = True,
195196
) -> result.RunResult[ResultData]:
196197
"""Run the agent with a user prompt in async mode.
@@ -211,8 +212,9 @@ async def run(
211212
message_history: History of the conversation so far.
212213
model: Optional model to use for this run, required if `model` was not set when creating the agent.
213214
deps: Optional dependencies to use for this run.
214-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
215215
model_settings: Optional settings to use for this model's request.
216+
usage_limits: Optional limits on model request count or token usage.
217+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
216218
217219
Returns:
218220
The result of the run.
@@ -237,12 +239,14 @@ async def run(
237239
for tool in self._function_tools.values():
238240
tool.current_retry = 0
239241

240-
usage = result.Usage()
241-
242+
usage = result.Usage(requests=0)
242243
model_settings = merge_model_settings(self.model_settings, model_settings)
244+
usage_limits = usage_limits or UsageLimits()
243245

244246
run_step = 0
245247
while True:
248+
usage_limits.check_before_request(usage)
249+
246250
run_step += 1
247251
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
248252
agent_model = await self._prepare_model(model_used, deps, messages)
@@ -254,6 +258,8 @@ async def run(
254258

255259
messages.append(model_response)
256260
usage += request_usage
261+
usage.requests += 1
262+
usage_limits.check_tokens(request_usage)
257263

258264
with _logfire.span('handle model response', run_step=run_step) as handle_span:
259265
final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
@@ -284,6 +290,7 @@ def run_sync(
284290
model: models.Model | models.KnownModelName | None = None,
285291
deps: AgentDeps = None,
286292
model_settings: ModelSettings | None = None,
293+
usage_limits: UsageLimits | None = None,
287294
infer_name: bool = True,
288295
) -> result.RunResult[ResultData]:
289296
"""Run the agent with a user prompt synchronously.
@@ -308,8 +315,9 @@ async def main():
308315
message_history: History of the conversation so far.
309316
model: Optional model to use for this run, required if `model` was not set when creating the agent.
310317
deps: Optional dependencies to use for this run.
311-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
312318
model_settings: Optional settings to use for this model's request.
319+
usage_limits: Optional limits on model request count or token usage.
320+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
313321
314322
Returns:
315323
The result of the run.
@@ -322,8 +330,9 @@ async def main():
322330
message_history=message_history,
323331
model=model,
324332
deps=deps,
325-
infer_name=False,
326333
model_settings=model_settings,
334+
usage_limits=usage_limits,
335+
infer_name=False,
327336
)
328337
)
329338

@@ -336,6 +345,7 @@ async def run_stream(
336345
model: models.Model | models.KnownModelName | None = None,
337346
deps: AgentDeps = None,
338347
model_settings: ModelSettings | None = None,
348+
usage_limits: UsageLimits | None = None,
339349
infer_name: bool = True,
340350
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
341351
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +367,9 @@ async def main():
357367
message_history: History of the conversation so far.
358368
model: Optional model to use for this run, required if `model` was not set when creating the agent.
359369
deps: Optional dependencies to use for this run.
360-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
361370
model_settings: Optional settings to use for this model's request.
371+
usage_limits: Optional limits on model request count or token usage.
372+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
362373
363374
Returns:
364375
The result of the run.
@@ -387,16 +398,19 @@ async def main():
387398

388399
usage = result.Usage()
389400
model_settings = merge_model_settings(self.model_settings, model_settings)
401+
usage_limits = usage_limits or UsageLimits()
390402

391403
run_step = 0
392404
while True:
393405
run_step += 1
406+
usage_limits.check_before_request(usage)
394407

395408
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
396409
agent_model = await self._prepare_model(model_used, deps, messages)
397410

398411
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
399412
async with agent_model.request_stream(messages, model_settings) as model_response:
413+
usage.requests += 1
400414
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
401415
# We want to end the "model request" span here, but we can't exit the context manager
402416
# in the traditional way
@@ -435,6 +449,7 @@ async def on_complete():
435449
messages,
436450
new_message_index,
437451
usage,
452+
usage_limits,
438453
result_stream,
439454
self._result_schema,
440455
deps,
@@ -456,7 +471,9 @@ async def on_complete():
456471
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
457472
handle_span.message = f'handle model response -> {tool_responses_str}'
458473
# the model_response should have been fully streamed by now, we can add its usage
459-
usage += model_response.usage()
474+
model_response_usage = model_response.usage()
475+
usage += model_response_usage
476+
usage_limits.check_tokens(usage)
460477

461478
@contextmanager
462479
def override(

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44

5-
__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
5+
__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
66

77

88
class ModelRetry(Exception):
@@ -30,7 +30,25 @@ def __init__(self, message: str):
3030
super().__init__(message)
3131

3232

33-
class UnexpectedModelBehavior(RuntimeError):
33+
class AgentRunError(RuntimeError):
34+
"""Base class for errors occurring during an agent run."""
35+
36+
message: str
37+
"""The error message."""
38+
39+
def __init__(self, message: str):
40+
self.message = message
41+
super().__init__(message)
42+
43+
def __str__(self) -> str:
44+
return self.message
45+
46+
47+
class UsageLimitExceeded(AgentRunError):
48+
"""Error raised when a Model's usage exceeds the specified limits."""
49+
50+
51+
class UnexpectedModelBehavior(AgentRunError):
3452
"""Error caused by unexpected Model behavior, e.g. an unexpected response code."""
3553

3654
message: str

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,11 @@ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]
282282
MessageParam(
283283
role='user',
284284
content=[
285-
ToolUseBlockParam(
286-
id=_guard_tool_call_id(t=part, model_source='Anthropic'),
287-
input=part.model_response(),
288-
name=part.tool_name,
289-
type='tool_use',
285+
ToolResultBlockParam(
286+
tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
287+
type='tool_result',
288+
content=part.model_response(),
289+
is_error=True,
290290
),
291291
],
292292
)

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def get(self, *, final: bool = False) -> ModelResponse:
237237
return ModelResponse(calls, timestamp=self._timestamp)
238238

239239
def usage(self) -> result.Usage:
240-
return result.Usage()
240+
return _estimate_usage([self.get()])
241241

242242
def timestamp(self) -> datetime:
243243
return self._timestamp
@@ -255,24 +255,24 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
255255
if isinstance(message, ModelRequest):
256256
for part in message.parts:
257257
if isinstance(part, (SystemPromptPart, UserPromptPart)):
258-
request_tokens += _string_usage(part.content)
258+
request_tokens += _estimate_string_usage(part.content)
259259
elif isinstance(part, ToolReturnPart):
260-
request_tokens += _string_usage(part.model_response_str())
260+
request_tokens += _estimate_string_usage(part.model_response_str())
261261
elif isinstance(part, RetryPromptPart):
262-
request_tokens += _string_usage(part.model_response())
262+
request_tokens += _estimate_string_usage(part.model_response())
263263
else:
264264
assert_never(part)
265265
elif isinstance(message, ModelResponse):
266266
for part in message.parts:
267267
if isinstance(part, TextPart):
268-
response_tokens += _string_usage(part.content)
268+
response_tokens += _estimate_string_usage(part.content)
269269
elif isinstance(part, ToolCallPart):
270270
call = part
271271
if isinstance(call.args, ArgsJson):
272272
args_str = call.args.args_json
273273
else:
274274
args_str = pydantic_core.to_json(call.args.args_dict).decode()
275-
response_tokens += 1 + _string_usage(args_str)
275+
response_tokens += 1 + _estimate_string_usage(args_str)
276276
else:
277277
assert_never(part)
278278
else:
@@ -282,5 +282,5 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
282282
)
283283

284284

285-
def _string_usage(content: str) -> int:
285+
def _estimate_string_usage(content: str) -> int:
286286
return len(re.split(r'[\s",.:]+', content))

0 commit comments

Comments
 (0)