
Commit 6e06684: multi-agent usage (#538)

Parent: 718c9a0

5 files changed: +122, -34 lines


pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 18 additions & 15 deletions
```diff
@@ -184,6 +184,7 @@ async def run(
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -206,6 +207,7 @@ async def run(
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -226,20 +228,19 @@ async def run(
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage(requests=0)
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()

             run_step = 0
             while True:
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)

                 run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
@@ -251,9 +252,8 @@ async def run(
                     model_req_span.set_attribute('usage', request_usage)

                 messages.append(model_response)
-                usage += request_usage
-                usage.requests += 1
-                usage_limits.check_tokens(request_usage)
+                run_context.usage.incr(request_usage, requests=1)
+                usage_limits.check_tokens(run_context.usage)

                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
                     final_result, tool_responses = await self._handle_model_response(model_response, run_context)
@@ -266,10 +266,10 @@ async def run(
                 if final_result is not None:
                     result_data = final_result.data
                     run_span.set_attribute('all_messages', messages)
-                    run_span.set_attribute('usage', usage)
+                    run_span.set_attribute('usage', run_context.usage)
                     handle_span.set_attribute('result', result_data)
                     handle_span.message = 'handle model response -> final result'
-                    return result.RunResult(messages, new_message_index, result_data, usage)
+                    return result.RunResult(messages, new_message_index, result_data, run_context.usage)
                 else:
                     # continue the conversation
                     handle_span.set_attribute('tool_responses', tool_responses)
@@ -285,6 +285,7 @@ def run_sync(
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -311,6 +312,7 @@ async def main():
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -326,6 +328,7 @@ async def main():
                 deps=deps,
                 model_settings=model_settings,
                 usage_limits=usage_limits,
+                usage=usage,
                 infer_name=False,
             )
         )
@@ -340,6 +343,7 @@ async def run_stream(
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -363,6 +367,7 @@ async def main():
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -385,28 +390,27 @@ async def main():
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage()
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()

             run_step = 0
             while True:
                 run_step += 1
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
                     agent_model = await self._prepare_model(run_context)

                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
+                        run_context.usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
@@ -442,7 +446,6 @@ async def on_complete():
                         yield result.StreamedRunResult(
                             messages,
                             new_message_index,
-                            usage,
                             usage_limits,
                             result_stream,
                             self._result_schema,
@@ -466,8 +469,8 @@ async def on_complete():
                     handle_span.message = f'handle model response -> {tool_responses_str}'
                     # the model_response should have been fully streamed by now, we can add its usage
                     model_response_usage = model_response.usage()
-                    usage += model_response_usage
-                    usage_limits.check_tokens(usage)
+                    run_context.usage.incr(model_response_usage)
+                    usage_limits.check_tokens(run_context.usage)

     @contextmanager
     def override(
```
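In practice, the new `usage` parameter lets a delegate agent count its requests and tokens against the calling agent's run. A minimal sketch of that pattern, mirroring `test_multi_agent_usage_no_incr` in the tests below (the agent and tool names are illustrative):

```python
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel

controller_agent = Agent(TestModel())
delegate_agent = Agent(TestModel(), result_type=int)


@controller_agent.tool
async def delegate(ctx: RunContext[None], sentence: str) -> int:
    # Pass the controller's running usage into the delegate run, so the
    # delegate's requests and tokens count against the same budget.
    delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
    return delegate_result.data


result = controller_agent.run_sync('foobar')
print(result.usage())  # includes the delegate agent's requests and tokens
```

Because the delegate increments the shared `Usage` object in place, the controller's `usage_limits` checks also see the delegate's spend.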

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 22 additions & 15 deletions
```diff
@@ -2,6 +2,7 @@

 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
+from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -63,25 +64,33 @@ class Usage:
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.

-        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
-        counts: dict[str, int] = {}
+        self.requests += requests
         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
-            other_value = getattr(other, f)
+            other_value = getattr(incr_usage, f)
             if self_value is not None or other_value is not None:
-                counts[f] = (self_value or 0) + (other_value or 0)
+                setattr(self, f, (self_value or 0) + (other_value or 0))

-        details = self.details.copy() if self.details is not None else None
-        if other.details is not None:
-            details = details or {}
-            for key, value in other.details.items():
-                details[key] = details.get(key, 0) + value
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value

-        return Usage(**counts, details=details or None)
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.
+
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage


 @dataclass
@@ -136,8 +145,6 @@ def usage(self) -> Usage:
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    usage_so_far: Usage
-    """Usage of the run up until the last request."""
     _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
@@ -306,7 +313,7 @@ def usage(self) -> Usage:
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self.usage_so_far + self._stream_response.usage()
+        return self._run_ctx.usage + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
```

pydantic_ai_slim/pydantic_ai/settings.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -136,6 +136,6 @@ def check_tokens(self, usage: Usage) -> None:
                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
             )

-        total_tokens = request_tokens + response_tokens
+        total_tokens = usage.total_tokens or 0
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
```
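This fix matters once a run can be seeded with prior usage: the seeded `total_tokens` may exceed the current run's `request_tokens + response_tokens`, so the limit now checks the reported total directly. A short illustration (the values are arbitrary):

```python
from pydantic_ai import UsageLimitExceeded
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits

limits = UsageLimits(total_tokens_limit=100)
limits.check_tokens(Usage(total_tokens=99))  # within the limit, no error

try:
    # e.g. usage seeded from a previous run via `usage=Usage(total_tokens=...)`
    limits.check_tokens(Usage(total_tokens=101))
except UsageLimitExceeded as exc:
    print(exc)  # Exceeded the total_tokens_limit of 100 (total_tokens=101)
```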

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -4,7 +4,7 @@
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast

 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
@@ -13,6 +13,9 @@
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior

+if TYPE_CHECKING:
+    from .result import Usage
+
 __all__ = (
     'AgentDeps',
     'RunContext',
@@ -45,6 +48,8 @@ class RunContext(Generic[AgentDeps]):
     """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""

     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
```
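Since `RunContext` now carries `usage`, any tool can inspect or adjust the run's spend directly, as the sync test below does with `ctx.usage.incr(...)`. A small sketch of reading it inside a tool (the tool name and message format are illustrative):

```python
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())


@agent.tool
def spend_so_far(ctx: RunContext[None]) -> str:
    # total_tokens may still be None before any tokens are recorded
    return f'{ctx.usage.requests} requests, {ctx.usage.total_tokens or 0} tokens so far'


result = agent.run_sync('how much have we spent?')
print(result.data)
```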

tests/test_usage_limits.py

Lines changed: 75 additions & 2 deletions
```diff
@@ -1,11 +1,20 @@
+import functools
+import operator
 import re
 from datetime import timezone

 import pytest
 from inline_snapshot import snapshot

-from pydantic_ai import Agent, UsageLimitExceeded
-from pydantic_ai.messages import ArgsDict, ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart
+from pydantic_ai import Agent, RunContext, UsageLimitExceeded
+from pydantic_ai.messages import (
+    ArgsDict,
+    ModelRequest,
+    ModelResponse,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.result import Usage
 from pydantic_ai.settings import UsageLimits
@@ -97,3 +106,67 @@ async def ret_a(x: str) -> str:
         UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)')
     ):
         await result.get_data()
+
+
+def test_usage_so_far(set_event_loop: None) -> None:
+    test_agent = Agent(TestModel())
+
+    with pytest.raises(
+        UsageLimitExceeded, match=re.escape('Exceeded the total_tokens_limit of 105 (total_tokens=163)')
+    ):
+        test_agent.run_sync(
+            'Hello, this prompt exceeds the request tokens limit.',
+            usage_limits=UsageLimits(total_tokens_limit=105),
+            usage=Usage(total_tokens=100),
+        )
+
+
+async def test_multi_agent_usage_no_incr():
+    delegate_agent = Agent(TestModel(), result_type=int)
+
+    controller_agent1 = Agent(TestModel())
+    run_1_usages: list[Usage] = []
+
+    @controller_agent1.tool
+    async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
+        delegate_result = await delegate_agent.run(sentence)
+        delegate_usage = delegate_result.usage()
+        run_1_usages.append(delegate_usage)
+        assert delegate_usage == snapshot(Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55))
+        return delegate_result.data
+
+    result1 = await controller_agent1.run('foobar')
+    assert result1.data == snapshot('{"delegate_to_other_agent1":0}')
+    run_1_usages.append(result1.usage())
+    assert result1.usage() == snapshot(Usage(requests=2, request_tokens=103, response_tokens=13, total_tokens=116))
+
+    controller_agent2 = Agent(TestModel())
+
+    @controller_agent2.tool
+    async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int:
+        delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
+        delegate_usage = delegate_result.usage()
+        assert delegate_usage == snapshot(Usage(requests=2, request_tokens=102, response_tokens=9, total_tokens=111))
+        return delegate_result.data
+
+    result2 = await controller_agent2.run('foobar')
+    assert result2.data == snapshot('{"delegate_to_other_agent2":0}')
+    assert result2.usage() == snapshot(Usage(requests=3, request_tokens=154, response_tokens=17, total_tokens=171))
+
+    # confirm the usage from result2 is the sum of the usage from result1
+    assert result2.usage() == functools.reduce(operator.add, run_1_usages)
+
+
+async def test_multi_agent_usage_sync():
+    """As in `test_multi_agent_usage_async`, with a sync tool."""
+    controller_agent = Agent(TestModel())
+
+    @controller_agent.tool
+    def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
+        new_usage = Usage(requests=5, request_tokens=2, response_tokens=3, total_tokens=4)
+        ctx.usage.incr(new_usage)
+        return 0
+
+    result = await controller_agent.run('foobar')
+    assert result.data == snapshot('{"delegate_to_other_agent":0}')
+    assert result.usage() == snapshot(Usage(requests=7, request_tokens=105, response_tokens=16, total_tokens=120))
```
