Skip to content

Commit d78ebc9

Browse files
authored
Fix agent name inference when using run_stream_events (#3279)
1 parent cf55a71 commit d78ebc9

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,9 @@ async def main():
683683
An async iterable of stream events `AgentStreamEvent` and finally a `AgentRunResultEvent` with the final
684684
run result.
685685
"""
686+
if infer_name and self.name is None:
687+
self._infer_name(inspect.currentframe())
688+
686689
# unfortunately this hack of returning a generator rather than defining it right here is
687690
# required to allow overloads of this method to work in python's typing system, or at least with pyright
688691
# or at least I couldn't make it work without
@@ -696,7 +699,6 @@ async def main():
696699
model_settings=model_settings,
697700
usage_limits=usage_limits,
698701
usage=usage,
699-
infer_name=infer_name,
700702
toolsets=toolsets,
701703
builtin_tools=builtin_tools,
702704
)
@@ -713,7 +715,6 @@ async def _run_stream_events(
713715
model_settings: ModelSettings | None = None,
714716
usage_limits: _usage.UsageLimits | None = None,
715717
usage: _usage.RunUsage | None = None,
716-
infer_name: bool = True,
717718
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
718719
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
719720
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
@@ -739,7 +740,7 @@ async def run_agent() -> AgentRunResult[Any]:
739740
model_settings=model_settings,
740741
usage_limits=usage_limits,
741742
usage=usage,
742-
infer_name=infer_name,
743+
infer_name=False,
743744
toolsets=toolsets,
744745
builtin_tools=builtin_tools,
745746
event_stream_handler=event_stream_handler,

tests/test_streaming.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,7 @@ async def ret_a(x: str) -> str:
18171817
return f'{x}-apple'
18181818

18191819
events = [event async for event in test_agent.run_stream_events('Hello')]
1820+
assert test_agent.name == 'test_agent'
18201821

18211822
assert events == snapshot(
18221823
[

0 commit comments

Comments
 (0)