Skip to content

Commit 9725ec2

Browse files
committed
Progress
1 parent b006441 commit 9725ec2

File tree

11 files changed

+256
-39
lines changed

11 files changed

+256
-39
lines changed

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
BinaryImage,
4343
BuiltinToolCallPart,
4444
BuiltinToolReturnPart,
45+
CustomEvent,
4546
DocumentFormat,
4647
DocumentMediaType,
4748
DocumentUrl,
@@ -68,6 +69,7 @@
6869
PartEndEvent,
6970
PartStartEvent,
7071
RetryPromptPart,
72+
Return,
7173
SystemPromptPart,
7274
TextPart,
7375
TextPartDelta,
@@ -141,6 +143,7 @@
141143
'BinaryContent',
142144
'BuiltinToolCallPart',
143145
'BuiltinToolReturnPart',
146+
'CustomEvent',
144147
'DocumentFormat',
145148
'DocumentMediaType',
146149
'DocumentUrl',
@@ -168,6 +171,7 @@
168171
'PartEndEvent',
169172
'PartStartEvent',
170173
'RetryPromptPart',
174+
'Return',
171175
'SystemPromptPart',
172176
'TextPart',
173177
'TextPartDelta',

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ async def stream(
438438
_output_schema=ctx.deps.output_schema,
439439
_model_request_parameters=model_request_parameters,
440440
_output_validators=ctx.deps.output_validators,
441-
_run_ctx=build_run_context(ctx),
441+
_run_ctx=run_context,
442442
_usage_limits=ctx.deps.usage_limits,
443443
_tool_manager=ctx.deps.tool_manager,
444444
)
@@ -646,6 +646,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
646646

647647
if text_processor := output_schema.text_processor:
648648
if text:
649+
# TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
649650
self._next_node = await self._handle_text_response(ctx, text, text_processor)
650651
return
651652
alternatives.insert(0, 'return text')
@@ -1072,7 +1073,8 @@ async def _call_tool(
10721073
except ToolRetryError as e:
10731074
return e.tool_retry, None
10741075

1075-
if isinstance(tool_result, _messages.ToolReturn):
1076+
tool_return: _messages.Return | None = None
1077+
if isinstance(tool_result, _messages.Return):
10761078
tool_return = tool_result
10771079
else:
10781080
result_is_list = isinstance(tool_result, list)
@@ -1083,8 +1085,8 @@ async def _call_tool(
10831085
for content in contents:
10841086
if isinstance(content, _messages.ToolReturn):
10851087
raise exceptions.UserError(
1086-
f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
1087-
f'`ToolReturn` should be used directly.'
1088+
f'The return value of tool {tool_call.tool_name!r} contains invalid nested `Return` objects. '
1089+
f'`Return` should be used directly.'
10881090
)
10891091
elif isinstance(content, _messages.MultiModalContent):
10901092
identifier = content.identifier
@@ -1116,10 +1118,13 @@ async def _call_tool(
11161118
tool_name=tool_call.tool_name,
11171119
tool_call_id=tool_call.tool_call_id,
11181120
content=tool_return.return_value, # type: ignore
1119-
metadata=tool_return.metadata,
11201121
)
11211122

1122-
return return_part, tool_return.content or None
1123+
if isinstance(tool_return, _messages.ToolReturn):
1124+
return_part.metadata = tool_return.metadata
1125+
return return_part, tool_return.content or None
1126+
else:
1127+
return return_part, None
11231128

11241129

11251130
@dataclasses.dataclass

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations as _annotations
77

88
from collections.abc import Awaitable, Callable
9-
from dataclasses import dataclass, field
9+
from dataclasses import dataclass, field, replace
1010
from inspect import Parameter, signature
1111
from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin
1212

@@ -19,7 +19,7 @@
1919
from pydantic_core import SchemaValidator, core_schema
2020
from typing_extensions import ParamSpec, TypeIs, TypeVar
2121

22-
from pydantic_ai.messages import CustomEvent, ToolReturn
22+
from pydantic_ai.messages import CustomEvent, Return
2323

2424
from ._griffe import doc_descriptions
2525
from ._run_context import RunContext
@@ -61,18 +61,26 @@ async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
6161
'RunContext.event_stream needs to be set to use FunctionSchema.call with async iterators'
6262
)
6363

64-
async for event_payload in self.function(*args, **kwargs):
65-
if isinstance(event_payload, ToolReturn):
66-
return event_payload
64+
return_value: Return | None = None
65+
async for event_data in self.function(*args, **kwargs):
66+
if return_value is not None:
67+
from .exceptions import UserError
6768

68-
event = (
69-
cast(CustomEvent, event_payload)
70-
if isinstance(event_payload, CustomEvent)
71-
else CustomEvent(payload=event_payload)
72-
)
69+
raise UserError('Return value must be the last value yielded by the function')
70+
71+
if isinstance(event_data, Return):
72+
return_value = cast(Return, event_data)
73+
continue
74+
75+
if isinstance(event_data, CustomEvent):
76+
event = cast(CustomEvent, event_data)
77+
if ctx.tool_call_id:
78+
event = replace(event, tool_call_id=ctx.tool_call_id)
79+
else:
80+
event = CustomEvent(data=event_data, tool_call_id=ctx.tool_call_id)
7381
await ctx.event_stream.send(event)
74-
# TODO (DouweM): Raise if events are yielded after ToolReturn
75-
return None
82+
83+
return return_value
7684
elif self.is_async:
7785
function = cast(Callable[[Any], Awaitable[str]], self.function)
7886
return await function(*args, **kwargs)

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ async def execute_traced_output_function(
134134
) as span:
135135
try:
136136
output = await function_schema.call(args, run_context)
137+
if isinstance(output, _messages.Return):
138+
output = cast(_messages.Return[Any], output).return_value
137139
except ModelRetry as r:
138140
if wrap_validation_errors:
139141
m = _messages.RetryPromptPart(

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
if TYPE_CHECKING:
2424
from .models.instrumented import InstrumentationSettings
2525

26-
EventPayloadT = TypeVar('EventPayloadT', default=Any)
26+
EventDataT = TypeVar('EventPayloadT', default=Any)
2727

2828

2929
AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac']
@@ -617,9 +617,23 @@ def __init__(
617617
MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
618618
UserContent: TypeAlias = str | MultiModalContent
619619

620+
# TODO (DouweM): thinking about variance
621+
ReturnValueT = TypeVar('ReturnValueT', default=Any, covariant=True)
622+
623+
624+
@dataclass(repr=False)
625+
class Return(Generic[ReturnValueT]):
626+
"""TODO (DouweM): Docstring."""
627+
628+
# TODO (DouweM): Find a better name, or get rid of this entirely?
629+
630+
return_value: ReturnValueT
631+
632+
__repr__ = _utils.dataclasses_no_defaults_repr
633+
620634

621635
@dataclass(repr=False)
622-
class ToolReturn:
636+
class ToolReturn(Return[ReturnValueT]):
623637
"""A structured return value for tools that need to provide both a return value and custom content to the model.
624638
625639
This class allows tools to return complex responses that include:
@@ -628,9 +642,6 @@ class ToolReturn:
628642
- Optional metadata for application use
629643
"""
630644

631-
return_value: Any
632-
"""The return value to be used in the tool response."""
633-
634645
_: KW_ONLY
635646

636647
content: str | Sequence[UserContent] | None = None
@@ -1779,19 +1790,16 @@ class BuiltinToolResultEvent:
17791790

17801791

17811792
@dataclass(repr=False)
1782-
class CustomEvent(Generic[EventPayloadT]):
1793+
class CustomEvent(Generic[EventDataT]):
17831794
"""An event indicating the result of a function tool call."""
17841795

1785-
payload: EventPayloadT
1786-
"""The payload of the custom event."""
1796+
data: EventDataT
1797+
"""The data of the custom event."""
17871798

17881799
_: KW_ONLY
17891800

1790-
name: str | None = None
1791-
"""The optional name of the custom event."""
1792-
1793-
id: str | None = None
1794-
"""The optional ID of the custom event."""
1801+
tool_call_id: str | None = None
1802+
"""The tool call ID, if any, that this event is associated with."""
17951803

17961804
event_kind: Literal['custom'] = 'custom'
17971805
"""Event type identifier, used as a discriminator."""

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Awaitable, Callable, Sequence
3+
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
44
from dataclasses import dataclass
55
from typing import Any, Generic, Literal
66

@@ -11,7 +11,7 @@
1111

1212
from . import _utils, exceptions
1313
from ._json_schema import InlineDefsJsonSchemaTransformer
14-
from .messages import ToolCallPart
14+
from .messages import Return, ToolCallPart
1515
from .tools import DeferredToolRequests, ObjectJsonSchema, RunContext, ToolDefinition
1616

1717
__all__ = (
@@ -45,10 +45,19 @@
4545
StructuredOutputMode = Literal['tool', 'native', 'prompted']
4646
"""Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode"""
4747

48-
49-
OutputTypeOrFunction = TypeAliasType(
50-
'OutputTypeOrFunction', type[T_co] | Callable[..., Awaitable[T_co] | T_co], type_params=(T_co,)
48+
OutputFunction = TypeAliasType(
49+
'OutputFunction',
50+
Callable[..., AsyncIterator[Return[T_co] | Any]] | Callable[..., Awaitable[T_co]] | Callable[..., T_co],
51+
type_params=(T_co,),
5152
)
53+
"""Definition of an output function.
54+
55+
You should not need to import or use this type directly.
56+
57+
See [output docs](../output.md) for more information.
58+
"""
59+
60+
OutputTypeOrFunction = TypeAliasType('OutputTypeOrFunction', OutputFunction[T_co] | type[T_co], type_params=(T_co,))
5261
"""Definition of an output type or function.
5362
5463
You should not need to import or use this type directly.

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ async def validate_response_output(
170170
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
171171
f'Invalid response, unable to find tool call for {output_tool_name!r}'
172172
)
173+
# TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
173174
return await self._tool_manager.handle_call(
174175
tool_call, allow_partial=allow_partial, wrap_validation_errors=False
175176
)
@@ -191,6 +192,7 @@ async def validate_response_output(
191192
# not part of the final result output, so we reset the accumulated text
192193
text = ''
193194

195+
# TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
194196
result_data = await text_processor.process(
195197
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
196198
)

pydantic_ai_slim/pydantic_ai/ui/_event_stream.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
BuiltinToolCallPart,
1616
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
1717
BuiltinToolReturnPart,
18+
CustomEvent,
1819
FilePart,
1920
FinalResultEvent,
2021
FunctionToolCallEvent,
@@ -229,7 +230,7 @@ async def _turn_to(self, to_turn: Literal['request', 'response'] | None) -> Asyn
229230
async for e in self.before_response():
230231
yield e
231232

232-
async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]:
233+
async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: # noqa: C901
233234
"""Transform a Pydantic AI event into one or more protocol-specific events.
234235
235236
This method dispatches to specific `handle_*` methods based on event type:
@@ -240,6 +241,7 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]:
240241
- [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] -> `handle_final_result`
241242
- [`FunctionToolCallEvent`][pydantic_ai.messages.FunctionToolCallEvent] -> `handle_function_tool_call`
242243
- [`FunctionToolResultEvent`][pydantic_ai.messages.FunctionToolResultEvent] -> `handle_function_tool_result`
244+
- [`CustomEvent`][pydantic_ai.messages.CustomEvent] -> `handle_custom_event`
243245
- [`AgentRunResultEvent`][pydantic_ai.run.AgentRunResultEvent] -> `handle_run_result`
244246
245247
Subclasses are encouraged to override the individual `handle_*` methods rather than this one.
@@ -264,6 +266,9 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]:
264266
case FunctionToolResultEvent():
265267
async for e in self.handle_function_tool_result(event):
266268
yield e
269+
case CustomEvent():
270+
async for e in self.handle_custom_event(event):
271+
yield e
267272
case AgentRunResultEvent():
268273
async for e in self.handle_run_result(event):
269274
yield e
@@ -581,6 +586,15 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
581586
return # pragma: no cover
582587
yield # Make this an async generator
583588

589+
async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[EventT]:
590+
"""Handle a `CustomEvent`.
591+
592+
Args:
593+
event: The custom event.
594+
"""
595+
return
596+
yield # Make this an async generator
597+
584598
async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[EventT]:
585599
"""Handle an `AgentRunResultEvent`.
586600

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ...messages import (
1515
BuiltinToolCallPart,
1616
BuiltinToolReturnPart,
17+
CustomEvent,
1718
FunctionToolResultEvent,
1819
RetryPromptPart,
1920
TextPart,
@@ -234,3 +235,7 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
234235
for item in possible_event: # type: ignore[reportUnknownMemberType]
235236
if isinstance(item, BaseEvent): # pragma: no branch
236237
yield item
238+
239+
async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[BaseEvent]:
240+
if isinstance(event.data, BaseEvent):
241+
yield event.data

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ...messages import (
1212
BuiltinToolCallPart,
1313
BuiltinToolReturnPart,
14+
CustomEvent,
1415
FilePart,
1516
FunctionToolResultEvent,
1617
RetryPromptPart,
@@ -185,3 +186,7 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A
185186
yield ToolOutputAvailableChunk(tool_call_id=result.tool_call_id, output=result.content)
186187

187188
# ToolCallResultEvent.content may hold user parts (e.g. text, images) that Vercel AI does not currently have events for
189+
190+
async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[BaseChunk]:
191+
if isinstance(event.data, BaseChunk):
192+
yield event.data

0 commit comments

Comments
 (0)