Skip to content

Commit bda1daf

Browse files
authored
Adding Agent.name (#141)
1 parent 8a6c93b commit bda1daf

File tree

5 files changed

+91
-8
lines changed

5 files changed

+91
-8
lines changed

docs/api/agent.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
options:
55
members:
66
- __init__
7+
- name
78
- run
89
- run_sync
910
- run_stream

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import asyncio
44
import dataclasses
5+
import inspect
56
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
67
from contextlib import asynccontextmanager, contextmanager
78
from dataclasses import dataclass, field
9+
from types import FrameType
810
from typing import Any, Callable, Generic, cast, final, overload
911

1012
import logfire_api
@@ -54,6 +56,11 @@ class Agent(Generic[AgentDeps, ResultData]):
5456
# dataclass fields mostly for my sanity — knowing what attributes are available
5557
model: models.Model | models.KnownModelName | None
5658
"""The default model configured for this agent."""
59+
name: str | None
60+
"""The name of the agent, used for logging.
61+
62+
If `None`, we try to infer the agent name from the call frame when the agent is first run.
63+
"""
5764
_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
5865
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
5966
_allow_text_result: bool = field(repr=False)
@@ -79,6 +86,7 @@ def __init__(
7986
result_type: type[ResultData] = str,
8087
system_prompt: str | Sequence[str] = (),
8188
deps_type: type[AgentDeps] = NoneType,
89+
name: str | None = None,
8290
retries: int = 1,
8391
result_tool_name: str = 'final_result',
8492
result_tool_description: str | None = None,
@@ -98,6 +106,8 @@ def __init__(
98106
parameterize the agent, and therefore get the best out of static type checking.
99107
If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
100108
or add a type hint `: Agent[None, <return type>]`.
109+
name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
110+
when the agent is first run.
101111
retries: The default number of retries to allow before raising an error.
102112
result_tool_name: The name of the tool to use for the final result.
103113
result_tool_description: The description of the final result tool.
@@ -115,6 +125,7 @@ def __init__(
115125
else:
116126
self.model = models.infer_model(model)
117127

128+
self.name = name
118129
self._result_schema = _result.ResultSchema[result_type].build(
119130
result_type, result_tool_name, result_tool_description
120131
)
@@ -139,6 +150,7 @@ async def run(
139150
message_history: list[_messages.Message] | None = None,
140151
model: models.Model | models.KnownModelName | None = None,
141152
deps: AgentDeps = None,
153+
infer_name: bool = True,
142154
) -> result.RunResult[ResultData]:
143155
"""Run the agent with a user prompt in async mode.
144156
@@ -147,16 +159,19 @@ async def run(
147159
message_history: History of the conversation so far.
148160
model: Optional model to use for this run, required if `model` was not set when creating the agent.
149161
deps: Optional dependencies to use for this run.
162+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
150163
151164
Returns:
152165
The result of the run.
153166
"""
167+
if infer_name and self.name is None:
168+
self._infer_name(inspect.currentframe())
154169
model_used, custom_model, agent_model = await self._get_agent_model(model)
155170

156171
deps = self._get_deps(deps)
157172

158173
with _logfire.span(
159-
'agent run {prompt=}',
174+
'{agent.name} run {prompt=}',
160175
prompt=user_prompt,
161176
agent=self,
162177
custom_model=custom_model,
@@ -208,6 +223,7 @@ def run_sync(
208223
message_history: list[_messages.Message] | None = None,
209224
model: models.Model | models.KnownModelName | None = None,
210225
deps: AgentDeps = None,
226+
infer_name: bool = True,
211227
) -> result.RunResult[ResultData]:
212228
"""Run the agent with a user prompt synchronously.
213229
@@ -218,12 +234,17 @@ def run_sync(
218234
message_history: History of the conversation so far.
219235
model: Optional model to use for this run, required if `model` was not set when creating the agent.
220236
deps: Optional dependencies to use for this run.
237+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
221238
222239
Returns:
223240
The result of the run.
224241
"""
242+
if infer_name and self.name is None:
243+
self._infer_name(inspect.currentframe())
225244
loop = asyncio.get_event_loop()
226-
return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
245+
return loop.run_until_complete(
246+
self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
247+
)
227248

228249
@asynccontextmanager
229250
async def run_stream(
@@ -233,6 +254,7 @@ async def run_stream(
233254
message_history: list[_messages.Message] | None = None,
234255
model: models.Model | models.KnownModelName | None = None,
235256
deps: AgentDeps = None,
257+
infer_name: bool = True,
236258
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
237259
"""Run the agent with a user prompt in async mode, returning a streamed response.
238260
@@ -241,16 +263,21 @@ async def run_stream(
241263
message_history: History of the conversation so far.
242264
model: Optional model to use for this run, required if `model` was not set when creating the agent.
243265
deps: Optional dependencies to use for this run.
266+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
244267
245268
Returns:
246269
The result of the run.
247270
"""
271+
if infer_name and self.name is None:
272+
# f_back because `asynccontextmanager` adds one frame
273+
if frame := inspect.currentframe(): # pragma: no branch
274+
self._infer_name(frame.f_back)
248275
model_used, custom_model, agent_model = await self._get_agent_model(model)
249276

250277
deps = self._get_deps(deps)
251278

252279
with _logfire.span(
253-
'agent run stream {prompt=}',
280+
'{agent.name} run stream {prompt=}',
254281
prompt=user_prompt,
255282
agent=self,
256283
custom_model=custom_model,
@@ -798,6 +825,19 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps:
798825
else:
799826
return deps
800827

828+
def _infer_name(self, function_frame: FrameType | None) -> None:
829+
"""Infer the agent name from the call frame.
830+
831+
Usage should be `self._infer_name(inspect.currentframe())`.
832+
"""
833+
assert self.name is None, 'Name already set'
834+
if function_frame is not None: # pragma: no branch
835+
if parent_frame := function_frame.f_back: # pragma: no branch
836+
for name, item in parent_frame.f_locals.items():
837+
if item is self:
838+
self.name = name
839+
return
840+
801841

802842
@dataclass
803843
class _MarkFinalResult(Generic[ResultData]):

tests/test_agent.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
from .conftest import IsNow, TestEnv
2929

30+
pytestmark = pytest.mark.anyio
31+
3032

3133
def test_result_tuple(set_event_loop: None):
3234
def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
@@ -69,7 +71,10 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
6971

7072
agent = Agent(FunctionModel(return_model), result_type=Foo)
7173

74+
assert agent.name is None
75+
7276
result = agent.run_sync('Hello')
77+
assert agent.name == 'agent'
7378
assert isinstance(result.data, Foo)
7479
assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
7580
assert result.all_messages() == snapshot(
@@ -535,3 +540,37 @@ async def make_request() -> str:
535540
for _ in range(2):
536541
result = agent.run_sync('Hello')
537542
assert result.data == '{"make_request":"200"}'
543+
544+
545+
async def test_agent_name():
546+
my_agent = Agent('test')
547+
548+
assert my_agent.name is None
549+
550+
await my_agent.run('Hello', infer_name=False)
551+
assert my_agent.name is None
552+
553+
await my_agent.run('Hello')
554+
assert my_agent.name == 'my_agent'
555+
556+
557+
async def test_agent_name_already_set():
558+
my_agent = Agent('test', name='fig_tree')
559+
560+
assert my_agent.name == 'fig_tree'
561+
562+
await my_agent.run('Hello')
563+
assert my_agent.name == 'fig_tree'
564+
565+
566+
async def test_agent_name_changes():
567+
my_agent = Agent('test')
568+
569+
await my_agent.run('Hello')
570+
assert my_agent.name == 'my_agent'
571+
572+
new_agent = my_agent
573+
del my_agent
574+
575+
await new_agent.run('Hello')
576+
assert new_agent.name == 'my_agent'

tests/test_streaming.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@
3131
async def test_streamed_text_response():
3232
m = TestModel()
3333

34-
agent = Agent(m)
34+
test_agent = Agent(m)
35+
assert test_agent.name is None
3536

36-
@agent.tool_plain
37+
@test_agent.tool_plain
3738
async def ret_a(x: str) -> str:
3839
return f'{x}-apple'
3940

40-
async with agent.run_stream('Hello') as result:
41+
async with test_agent.run_stream('Hello') as result:
42+
assert test_agent.name == 'test_agent'
4143
assert not result.is_structured
4244
assert not result.is_complete
4345
assert result.all_messages() == snapshot(
@@ -71,9 +73,10 @@ async def ret_a(x: str) -> str:
7173
async def test_streamed_structured_response():
7274
m = TestModel()
7375

74-
agent = Agent(m, result_type=tuple[str, str])
76+
agent = Agent(m, result_type=tuple[str, str], name='fig_jam')
7577

7678
async with agent.run_stream('') as result:
79+
assert agent.name == 'fig_jam'
7780
assert result.is_structured
7881
assert not result.is_complete
7982
response = await result.get_data()

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_tool_return_conflict():
335335

336336

337337
def test_init_ctx_tool_invalid():
338-
def plain_tool(x: int) -> int:
338+
def plain_tool(x: int) -> int: # pragma: no cover
339339
return x + 1
340340

341341
m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'

0 commit comments

Comments
 (0)