Skip to content

Commit 0b1dea3

Browse files
committed
tests
1 parent 42e39e2 commit 0b1dea3

File tree

9 files changed

+560
-131
lines changed

9 files changed

+560
-131
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .usage import RunUsage, UsageLimits
2121

2222
try:
23+
from ag_ui.core import BaseEvent
2324
from ag_ui.core.types import RunAgentInput
2425

2526
from .ui import OnCompleteFunc, StateDeps, StateHandler
@@ -171,7 +172,7 @@ async def handle_ag_ui_request(
171172
usage: RunUsage | None = None,
172173
infer_name: bool = True,
173174
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
174-
on_complete: OnCompleteFunc | None = None,
175+
on_complete: OnCompleteFunc[BaseEvent] | None = None,
175176
) -> Response:
176177
"""Handle an AG-UI request by running the agent and returning a streaming response.
177178
@@ -226,7 +227,7 @@ async def run_ag_ui(
226227
usage: RunUsage | None = None,
227228
infer_name: bool = True,
228229
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
229-
on_complete: OnCompleteFunc | None = None,
230+
on_complete: OnCompleteFunc[BaseEvent] | None = None,
230231
) -> AsyncIterator[str]:
231232
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
232233

pydantic_ai_slim/pydantic_ai/ui/adapter.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
from abc import ABC, abstractmethod
10-
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping, Sequence
10+
from collections.abc import AsyncIterator, Mapping, Sequence
1111
from dataclasses import Field, dataclass, replace
1212
from functools import cached_property
1313
from http import HTTPStatus
@@ -17,24 +17,23 @@
1717
ClassVar,
1818
Generic,
1919
Protocol,
20-
TypeAlias,
2120
TypeVar,
2221
runtime_checkable,
2322
)
2423

2524
from pydantic import BaseModel, ValidationError
2625

27-
from .. import DeferredToolRequests, DeferredToolResults, _utils
28-
from ..agent import AbstractAgent, AgentDepsT, AgentRunResult
26+
from .. import DeferredToolRequests, DeferredToolResults
27+
from ..agent import AbstractAgent, AgentDepsT
2928
from ..builtin_tools import AbstractBuiltinTool
3029
from ..exceptions import UserError
3130
from ..messages import ModelMessage
3231
from ..models import KnownModelName, Model
33-
from ..output import OutputSpec
32+
from ..output import OutputDataT, OutputSpec
3433
from ..settings import ModelSettings
3534
from ..toolsets import AbstractToolset
3635
from ..usage import RunUsage, UsageLimits
37-
from .event_stream import BaseEventStream, SourceEvent
36+
from .event_stream import BaseEventStream, OnCompleteFunc, SourceEvent
3837

3938
if TYPE_CHECKING:
4039
from starlette.requests import Request
@@ -55,9 +54,6 @@
5554
EventT = TypeVar('EventT')
5655
"""Type variable for protocol-specific event types."""
5756

58-
OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]]
59-
"""Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""
60-
6157

6258
# State management types
6359

@@ -111,10 +107,10 @@ class StateDeps(Generic[StateT]):
111107

112108

113109
@dataclass
114-
class BaseAdapter(ABC, Generic[RunRequestT, MessageT, EventT, AgentDepsT]):
110+
class BaseAdapter(ABC, Generic[RunRequestT, MessageT, EventT, AgentDepsT, OutputDataT]):
115111
"""TODO (DouwM): Docstring."""
116112

117-
agent: AbstractAgent[AgentDepsT]
113+
agent: AbstractAgent[AgentDepsT, OutputDataT]
118114
"""The Pydantic AI agent to run."""
119115

120116
request: RunRequestT
@@ -134,7 +130,7 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
134130

135131
@property
136132
@abstractmethod
137-
def event_stream(self) -> BaseEventStream[RunRequestT, EventT, AgentDepsT]:
133+
def event_stream(self) -> BaseEventStream[RunRequestT, EventT, AgentDepsT, OutputDataT]:
138134
"""Create an event stream for the adapter."""
139135
raise NotImplementedError
140136

@@ -178,24 +174,17 @@ def encode_stream(self, stream: AsyncIterator[EventT], accept: str | None = None
178174
async def process_stream(
179175
self,
180176
stream: AsyncIterator[SourceEvent],
181-
on_complete: OnCompleteFunc | None = None,
177+
on_complete: OnCompleteFunc[EventT] | None = None,
182178
) -> AsyncIterator[EventT]:
183179
"""Process a stream of events and return a stream of events.
184180
185181
Args:
186182
stream: The stream of events to process.
187183
on_complete: Optional callback function called when the agent run completes successfully.
188184
"""
189-
event_stream = self.event_stream
190-
async for event in event_stream.handle_stream(stream):
185+
async for event in self.event_stream.handle_stream(stream, on_complete=on_complete):
191186
yield event
192187

193-
if (result := event_stream.result) and on_complete is not None:
194-
if _utils.is_async_callable(on_complete):
195-
await on_complete(result)
196-
else:
197-
await _utils.run_in_executor(on_complete, result)
198-
199188
async def run_stream(
200189
self,
201190
*,
@@ -210,7 +199,7 @@ async def run_stream(
210199
infer_name: bool = True,
211200
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
212201
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
213-
on_complete: OnCompleteFunc | None = None,
202+
on_complete: OnCompleteFunc[EventT] | None = None,
214203
) -> AsyncIterator[EventT]:
215204
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
216205
@@ -298,7 +287,7 @@ async def stream_response(self, stream: AsyncIterator[EventT], accept: str | Non
298287
@classmethod
299288
async def dispatch_request(
300289
cls,
301-
agent: AbstractAgent[AgentDepsT, Any],
290+
agent: AbstractAgent[AgentDepsT, OutputDataT],
302291
request: Request,
303292
*,
304293
message_history: Sequence[ModelMessage] | None = None,
@@ -312,7 +301,7 @@ async def dispatch_request(
312301
infer_name: bool = True,
313302
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
314303
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
315-
on_complete: OnCompleteFunc | None = None,
304+
on_complete: OnCompleteFunc[EventT] | None = None,
316305
) -> Response:
317306
"""Handle an AG-UI request and return a streaming response.
318307

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ToolReturnPart,
2626
UserPromptPart,
2727
)
28+
from ...output import OutputDataT
2829
from ...toolsets import AbstractToolset
2930

3031
try:
@@ -92,7 +93,7 @@ def label(self) -> str:
9293
return 'the AG-UI frontend tools' # pragma: no cover
9394

9495

95-
class AGUIAdapter(BaseAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT]):
96+
class AGUIAdapter(BaseAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT, OutputDataT]):
9697
"""TODO (DouwM): Docstring."""
9798

9899
@classmethod
@@ -101,7 +102,7 @@ async def validate_request(cls, request: Request) -> RunAgentInput:
101102
return RunAgentInput.model_validate(await request.json())
102103

103104
@property
104-
def event_stream(self) -> BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT]:
105+
def event_stream(self) -> BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]:
105106
"""Create an event stream for the adapter."""
106107
return AGUIEventStream(self.request)
107108

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import json
1010
from collections.abc import AsyncIterator, Iterable
11+
from dataclasses import dataclass, field
1112
from typing import Final
1213

1314
from ...messages import (
@@ -23,6 +24,7 @@
2324
ToolCallPartDelta,
2425
ToolReturnPart,
2526
)
27+
from ...output import OutputDataT
2628
from ...tools import AgentDepsT
2729
from .. import BaseEventStream
2830

@@ -69,14 +71,13 @@
6971
BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin'
7072

7173

72-
class AGUIEventStream(BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT]):
74+
@dataclass
75+
class AGUIEventStream(BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]):
7376
"""TODO (DouwM): Docstring."""
7477

75-
def __init__(self, request: RunAgentInput) -> None:
76-
"""Initialize AG-UI event stream state."""
77-
super().__init__(request)
78-
self._thinking_text = False
79-
self._builtin_tool_call_ids: dict[str, str] = {}
78+
_thinking_text: bool = False
79+
_builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)
80+
_error: bool = False
8081

8182
def encode_event(self, event: BaseEvent, accept: str | None = None) -> str:
8283
"""Encode an AG-UI event as SSE.
@@ -100,13 +101,15 @@ async def before_stream(self) -> AsyncIterator[BaseEvent]:
100101

101102
async def after_stream(self) -> AsyncIterator[BaseEvent]:
102103
"""Handle an AgentRunResultEvent, cleaning up any pending state."""
103-
yield RunFinishedEvent(
104-
thread_id=self.request.thread_id,
105-
run_id=self.request.run_id,
106-
)
104+
if not self._error:
105+
yield RunFinishedEvent(
106+
thread_id=self.request.thread_id,
107+
run_id=self.request.run_id,
108+
)
107109

108110
async def on_error(self, error: Exception) -> AsyncIterator[BaseEvent]:
109111
"""Handle errors during streaming."""
112+
self._error = True
110113
yield RunErrorEvent(message=str(error))
111114

112115
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseEvent]:

0 commit comments

Comments
 (0)