77from __future__ import annotations
88
99from abc import ABC , abstractmethod
10- from collections .abc import AsyncIterator , Awaitable , Callable , Mapping , Sequence
10+ from collections .abc import AsyncIterator , Mapping , Sequence
1111from dataclasses import Field , dataclass , replace
1212from functools import cached_property
1313from http import HTTPStatus
1717 ClassVar ,
1818 Generic ,
1919 Protocol ,
20- TypeAlias ,
2120 TypeVar ,
2221 runtime_checkable ,
2322)
2423
2524from pydantic import BaseModel , ValidationError
2625
27- from .. import DeferredToolRequests , DeferredToolResults , _utils
28- from ..agent import AbstractAgent , AgentDepsT , AgentRunResult
26+ from .. import DeferredToolRequests , DeferredToolResults
27+ from ..agent import AbstractAgent , AgentDepsT
2928from ..builtin_tools import AbstractBuiltinTool
3029from ..exceptions import UserError
3130from ..messages import ModelMessage
3231from ..models import KnownModelName , Model
33- from ..output import OutputSpec
32+ from ..output import OutputDataT , OutputSpec
3433from ..settings import ModelSettings
3534from ..toolsets import AbstractToolset
3635from ..usage import RunUsage , UsageLimits
37- from .event_stream import BaseEventStream , SourceEvent
36+ from .event_stream import BaseEventStream , OnCompleteFunc , SourceEvent
3837
3938if TYPE_CHECKING :
4039 from starlette .requests import Request
5554EventT = TypeVar ('EventT' )
5655"""Type variable for protocol-specific event types."""
5756
58- OnCompleteFunc : TypeAlias = Callable [[AgentRunResult [Any ]], None ] | Callable [[AgentRunResult [Any ]], Awaitable [None ]]
59- """Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""
60-
6157
6258# State management types
6359
@@ -111,10 +107,10 @@ class StateDeps(Generic[StateT]):
111107
112108
113109@dataclass
114- class BaseAdapter (ABC , Generic [RunRequestT , MessageT , EventT , AgentDepsT ]):
110+ class BaseAdapter (ABC , Generic [RunRequestT , MessageT , EventT , AgentDepsT , OutputDataT ]):
115111 """TODO (DouwM): Docstring."""
116112
117- agent : AbstractAgent [AgentDepsT ]
113+ agent : AbstractAgent [AgentDepsT , OutputDataT ]
118114 """The Pydantic AI agent to run."""
119115
120116 request : RunRequestT
@@ -134,7 +130,7 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
134130
135131 @property
136132 @abstractmethod
137- def event_stream (self ) -> BaseEventStream [RunRequestT , EventT , AgentDepsT ]:
133+ def event_stream (self ) -> BaseEventStream [RunRequestT , EventT , AgentDepsT , OutputDataT ]:
138134 """Create an event stream for the adapter."""
139135 raise NotImplementedError
140136
@@ -178,24 +174,17 @@ def encode_stream(self, stream: AsyncIterator[EventT], accept: str | None = None
178174 async def process_stream (
179175 self ,
180176 stream : AsyncIterator [SourceEvent ],
181- on_complete : OnCompleteFunc | None = None ,
177+ on_complete : OnCompleteFunc [ EventT ] | None = None ,
182178 ) -> AsyncIterator [EventT ]:
183179 """Process a stream of events and return a stream of events.
184180
185181 Args:
186182 stream: The stream of events to process.
187183 on_complete: Optional callback function called when the agent run completes successfully.
188184 """
189- event_stream = self .event_stream
190- async for event in event_stream .handle_stream (stream ):
185+ async for event in self .event_stream .handle_stream (stream , on_complete = on_complete ):
191186 yield event
192187
193- if (result := event_stream .result ) and on_complete is not None :
194- if _utils .is_async_callable (on_complete ):
195- await on_complete (result )
196- else :
197- await _utils .run_in_executor (on_complete , result )
198-
199188 async def run_stream (
200189 self ,
201190 * ,
@@ -210,7 +199,7 @@ async def run_stream(
210199 infer_name : bool = True ,
211200 toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
212201 builtin_tools : Sequence [AbstractBuiltinTool ] | None = None ,
213- on_complete : OnCompleteFunc | None = None ,
202+ on_complete : OnCompleteFunc [ EventT ] | None = None ,
214203 ) -> AsyncIterator [EventT ]:
215204 """Run the agent with the AG-UI run input and stream AG-UI protocol events.
216205
@@ -298,7 +287,7 @@ async def stream_response(self, stream: AsyncIterator[EventT], accept: str | Non
298287 @classmethod
299288 async def dispatch_request (
300289 cls ,
301- agent : AbstractAgent [AgentDepsT , Any ],
290+ agent : AbstractAgent [AgentDepsT , OutputDataT ],
302291 request : Request ,
303292 * ,
304293 message_history : Sequence [ModelMessage ] | None = None ,
@@ -312,7 +301,7 @@ async def dispatch_request(
312301 infer_name : bool = True ,
313302 toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
314303 builtin_tools : Sequence [AbstractBuiltinTool ] | None = None ,
315- on_complete : OnCompleteFunc | None = None ,
304+ on_complete : OnCompleteFunc [ EventT ] | None = None ,
316305 ) -> Response :
317306 """Handle an AG-UI request and return a streaming response.
318307
0 commit comments