Skip to content

Commit 4b46b1e

Browse files
committed
Progress
1 parent 2ab33e8 commit 4b46b1e

File tree

14 files changed

+415
-127
lines changed

14 files changed

+415
-127
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast
1414

1515
import anyio
16+
from anyio.streams.memory import MemoryObjectSendStream
1617
from opentelemetry.trace import Tracer
1718
from typing_extensions import TypeVar, assert_never
1819

@@ -599,7 +600,7 @@ async def _run(): # noqa: C901
599600
if text:
600601
try:
601602
self._next_node = await self._handle_text_response(
602-
ctx, text, text_processor
603+
ctx, text, text_processor, send_stream
603604
)
604605
return
605606
except ToolRetryError:
@@ -668,8 +669,9 @@ async def _run(): # noqa: C901
668669

669670
if text_processor := output_schema.text_processor:
670671
if text:
671-
# TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
672-
self._next_node = await self._handle_text_response(ctx, text, text_processor)
672+
self._next_node = await self._handle_text_response(
673+
ctx, text, text_processor, send_stream
674+
)
673675
return
674676
alternatives.insert(0, 'return text')
675677

@@ -737,8 +739,10 @@ async def _handle_text_response(
737739
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
738740
text: str,
739741
text_processor: _output.BaseOutputProcessor[NodeRunEndT],
742+
event_stream: MemoryObjectSendStream[_messages.CustomEvent],
740743
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
741744
run_context = build_run_context(ctx)
745+
run_context = replace(run_context, event_stream=event_stream)
742746

743747
result_data = await text_processor.process(text, run_context)
744748

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
6868
return_value = cast(Return, event_data)
6969
continue
7070

71-
# If there's no event stream, we're being called inside `agent.run_stream`, after event streaming has completed
72-
# and final result streaming has begun, so there's no need to yield custom events.
73-
# TODO (DouweM): Should we store events on the ToolReturnPart/RetryPromptPart or something, like we allow `metadata`?
71+
# If there's no event stream, we're being called from inside `agent.run_stream()` or `AgentStream.get_output()`,
72+
# after event streaming has completed and final result streaming has begun, so there's nowhere to yield custom events to.
73+
# We could consider storing the yielded events somewhere and letting them be accessed after the fact as a list.
7474
if ctx.event_stream is not None:
7575
if isinstance(event_data, CustomEvent):
7676
event = cast(CustomEvent, event_data)

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
5050
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
5151
if self.ctx is not None:
5252
if ctx.run_step == self.ctx.run_step:
53-
# TODO (DouweM): Refactor to make sure it's always set
54-
55-
if ctx.event_stream and not self.ctx.event_stream:
56-
self.ctx.event_stream = ctx.event_stream
5753
return self
5854

5955
retries = {

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
1313
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
14+
from pydantic_ai.messages import Return, ToolReturn
1415
from pydantic_ai.tools import AgentDepsT, ToolDefinition
1516

1617
from ._run_context import TemporalRunContext
@@ -43,7 +44,7 @@ class _ModelRetry:
4344

4445
@dataclass
4546
class _ToolReturn:
46-
result: Any
47+
result: ToolReturn[Any] | Any
4748
kind: Literal['tool_return'] = 'tool_return'
4849

4950

@@ -74,6 +75,9 @@ def visit_and_replace(
7475
async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
7576
try:
7677
result = await coro
78+
if type(result) is Return:
79+
# We don't use `isinstance` because `ToolReturn` is a subclass of `Return` with additional fields, which should be returned in full.
80+
result = result.return_value
7781
return _ToolReturn(result=result)
7882
except ApprovalRequired:
7983
return _ApprovalRequired()

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -617,16 +617,13 @@ def __init__(
617617
MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
618618
UserContent: TypeAlias = str | MultiModalContent
619619

620-
# TODO (DouweM): thinking about variance
621620
ReturnValueT = TypeVar('ReturnValueT', default=Any, covariant=True)
622621

623622

624623
@dataclass(repr=False)
625624
class Return(Generic[ReturnValueT]):
626625
"""TODO (DouweM): Docstring."""
627626

628-
# TODO (DouweM): Find a better name, or get rid of this entirely?
629-
630627
return_value: ReturnValueT
631628

632629
__repr__ = _utils.dataclasses_no_defaults_repr
@@ -1793,11 +1790,14 @@ class BuiltinToolResultEvent:
17931790
class CustomEvent(Generic[EventDataT]):
17941791
"""An event indicating the result of a function tool call."""
17951792

1796-
data: EventDataT
1793+
data: EventDataT = None
17971794
"""The data of the custom event."""
17981795

17991796
_: KW_ONLY
18001797

1798+
name: str | None = None
1799+
"""The name of the custom event."""
1800+
18011801
tool_call_id: str | None = None
18021802
"""The tool call ID, if any, that this event is associated with."""
18031803

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@
6969

7070
TextOutputFunc = TypeAliasType(
7171
'TextOutputFunc',
72-
Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co],
72+
Callable[[RunContext, str], AsyncIterator[Return[T_co] | Any]]
73+
| Callable[[RunContext, str], Awaitable[T_co]]
74+
| Callable[[RunContext, str], T_co]
75+
| Callable[[str], AsyncIterator[Return[T_co] | Any]]
76+
| Callable[[str], Awaitable[T_co]]
77+
| Callable[[str], T_co],
7378
type_params=(T_co,),
7479
)
7580
"""Definition of a function that will be called to process the model's plain text output. The function must take a single string argument.

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
116116
yield text
117117
else:
118118
async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
119-
# TODO (DouweM): What if there's an output function?
120119
for validator in self._output_validators:
121120
text = await validator.validate(text, replace(self._run_ctx, partial_output=True))
122121
yield text
@@ -171,7 +170,6 @@ async def validate_response_output(
171170
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
172171
f'Invalid response, unable to find tool call for {output_tool_name!r}'
173172
)
174-
# TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
175173
return await self._tool_manager.handle_call(
176174
tool_call, allow_partial=allow_partial, wrap_validation_errors=False
177175
)
@@ -193,7 +191,6 @@ async def validate_response_output(
193191
# not part of the final result output, so we reset the accumulated text
194192
text = ''
195193

196-
# TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
197194
result_data = await text_processor.process(
198195
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
199196
)

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
339339
max_retries=max_retries,
340340
args_validator=tool.function_schema.validator,
341341
call_func=tool.function_schema.call,
342-
is_async=tool.function_schema.is_async,
342+
is_async=tool.function_schema.is_async or tool.function_schema.is_async_iterator,
343343
)
344344
return tools
345345

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
try:
3333
from ag_ui.core import (
3434
BaseEvent,
35+
CustomEvent as AGUICustomEvent,
3536
EventType,
3637
RunAgentInput,
3738
RunErrorEvent,
@@ -239,3 +240,8 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
239240
async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[BaseEvent]:
240241
if isinstance(event.data, BaseEvent):
241242
yield event.data
243+
elif event.name:
244+
data = event.data
245+
if event.tool_call_id:
246+
data = {'tool_call_id': event.tool_call_id, 'data': data}
247+
yield AGUICustomEvent(name=event.name, value=data)

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .request_types import RequestData
2929
from .response_types import (
3030
BaseChunk,
31+
DataChunk,
3132
DoneChunk,
3233
ErrorChunk,
3334
FileChunk,
@@ -190,3 +191,8 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
190191
async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[BaseChunk]:
191192
if isinstance(event.data, BaseChunk):
192193
yield event.data
194+
elif event.name:
195+
data = event.data
196+
if event.tool_call_id:
197+
data = {'tool_call_id': event.tool_call_id, 'data': data}
198+
yield DataChunk(type=f'data-{event.name}', data=data)

0 commit comments

Comments
 (0)