Skip to content

Commit 5544992

Browse files
committed
Set Content-Type header on StreamingResponse
1 parent aebf039 commit 5544992

File tree

8 files changed

+88
-44
lines changed

8 files changed

+88
-44
lines changed

pydantic_ai_slim/pydantic_ai/ui/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from __future__ import annotations
88

99
from .adapter import BaseAdapter, OnCompleteFunc, StateDeps, StateHandler
10-
from .event_stream import BaseEventStream
10+
from .event_stream import SSE_CONTENT_TYPE, BaseEventStream
1111

1212
__all__ = [
1313
'BaseAdapter',
1414
'BaseEventStream',
15+
'SSE_CONTENT_TYPE',
1516
'StateDeps',
1617
'StateHandler',
1718
'OnCompleteFunc',

pydantic_ai_slim/pydantic_ai/ui/adapter.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,18 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
128128
"""Load messages from the request and return the loaded messages."""
129129
raise NotImplementedError
130130

131-
@property
132131
@abstractmethod
133-
def event_stream(self) -> BaseEventStream[RunRequestT, EventT, AgentDepsT, OutputDataT]:
134-
"""Create an event stream for the adapter."""
132+
def build_event_stream(
133+
self, accept: str | None = None
134+
) -> BaseEventStream[RunRequestT, EventT, AgentDepsT, OutputDataT]:
135+
"""Create an event stream for the adapter.
136+
137+
Args:
138+
accept: The accept header value.
139+
140+
Returns:
141+
The event stream.
142+
"""
135143
raise NotImplementedError
136144

137145
@cached_property
@@ -167,9 +175,9 @@ def encode_stream(self, stream: AsyncIterator[EventT], accept: str | None = None
167175
168176
Args:
169177
stream: The stream of events to encode.
170-
accept: The accept header value for encoding format.
178+
accept: The accept header value.
171179
"""
172-
return self.event_stream.encode_stream(stream, accept)
180+
return self.build_event_stream(accept).encode_stream(stream)
173181

174182
async def process_stream(
175183
self,
@@ -182,7 +190,7 @@ async def process_stream(
182190
stream: The stream of events to process.
183191
on_complete: Optional callback function called when the agent run completes successfully.
184192
"""
185-
async for event in self.event_stream.handle_stream(stream, on_complete=on_complete):
193+
async for event in self.build_event_stream().handle_stream(stream, on_complete=on_complete):
186194
yield event
187195

188196
async def run_stream(
@@ -266,7 +274,7 @@ async def stream_response(self, stream: AsyncIterator[EventT], accept: str | Non
266274
267275
Args:
268276
stream: The stream of events to encode.
269-
accept: The accept header value for encoding format.
277+
accept: The accept header value.
270278
"""
271279
try:
272280
from starlette.responses import StreamingResponse
@@ -276,12 +284,11 @@ async def stream_response(self, stream: AsyncIterator[EventT], accept: str | Non
276284
'you can use the `ui` optional group — `pip install "pydantic-ai-slim[ui]"`'
277285
) from e
278286

287+
event_stream = self.build_event_stream(accept)
279288
return StreamingResponse(
280-
self.encode_stream(
281-
stream,
282-
accept=accept,
283-
),
289+
event_stream.encode_stream(stream),
284290
headers=self.response_headers,
291+
media_type=event_stream.content_type,
285292
)
286293

287294
@classmethod

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,18 @@ async def validate_request(cls, request: Request) -> RunAgentInput:
101101
"""Validate the request and return the validated request."""
102102
return RunAgentInput.model_validate(await request.json())
103103

104-
@property
105-
def event_stream(self) -> BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]:
106-
"""Create an event stream for the adapter."""
107-
return AGUIEventStream(self.request)
104+
def build_event_stream(
105+
self, accept: str | None = None
106+
) -> BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]:
107+
"""Create an event stream for the adapter.
108+
109+
Args:
110+
accept: The accept header value.
111+
112+
Returns:
113+
The event stream.
114+
"""
115+
return AGUIEventStream(self.request, accept=accept)
108116

109117
@cached_property
110118
def toolset(self) -> AbstractToolset[AgentDepsT] | None:

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from ...output import OutputDataT
2828
from ...tools import AgentDepsT
29-
from .. import BaseEventStream
29+
from .. import SSE_CONTENT_TYPE, BaseEventStream
3030

3131
try:
3232
from ag_ui.core import (
@@ -64,10 +64,6 @@
6464
'RunFinishedEvent',
6565
]
6666

67-
SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
68-
"""Content type header value for Server-Sent Events (SSE)."""
69-
70-
7167
BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin'
7268

7369

@@ -79,18 +75,29 @@ class AGUIEventStream(BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, Outp
7975
_builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)
8076
_error: bool = False
8177

82-
def encode_event(self, event: BaseEvent, accept: str | None = None) -> str:
78+
@property
79+
def _event_encoder(self) -> EventEncoder:
80+
return EventEncoder(accept=self.accept or SSE_CONTENT_TYPE)
81+
82+
@property
83+
def content_type(self) -> str:
84+
"""Get the content type for the event stream, compatible with the accept header value.
85+
86+
Args:
87+
accept: The accept header value.
88+
"""
89+
return self._event_encoder.get_content_type()
90+
91+
def encode_event(self, event: BaseEvent) -> str:
8392
"""Encode an AG-UI event as SSE.
8493
8594
Args:
8695
event: The AG-UI event to encode.
87-
accept: The accept header value for encoding format.
8896
8997
Returns:
9098
The SSE-formatted string.
9199
"""
92-
encoder = EventEncoder(accept=accept or SSE_CONTENT_TYPE)
93-
return encoder.encode(event)
100+
return self._event_encoder.encode(event)
94101

95102
async def before_stream(self) -> AsyncIterator[BaseEvent]:
96103
"""Yield events before agent streaming starts."""

pydantic_ai_slim/pydantic_ai/ui/event_stream.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
'BaseEventStream',
4545
]
4646

47+
SSE_CONTENT_TYPE = 'text/event-stream'
48+
"""Content type header value for Server-Sent Events (SSE)."""
49+
4750
EventT = TypeVar('EventT')
4851
"""Type variable for protocol-specific event types."""
4952

@@ -65,12 +68,15 @@ class BaseEventStream(ABC, Generic[RunRequestT, EventT, AgentDepsT, OutputDataT]
6568
"""TODO (DouwM): Docstring."""
6669

6770
request: RunRequestT
68-
result: AgentRunResult[OutputDataT] | None = None
71+
72+
accept: str | None = None
73+
"""TODO (DouweM): Docstring"""
6974

7075
message_id: str = field(default_factory=lambda: str(uuid4()))
7176

7277
_turn: Literal['request', 'response'] | None = None
7378

79+
_result: AgentRunResult[OutputDataT] | None = None
7480
_final_result_event: FinalResultEvent | None = None
7581

7682
def new_message_id(self) -> str:
@@ -82,25 +88,35 @@ def new_message_id(self) -> str:
8288
self.message_id = str(uuid4())
8389
return self.message_id
8490

91+
@property
92+
def content_type(self) -> str:
93+
"""Get the content type for the event stream, compatible with the accept header value.
94+
95+
By default, this returns the SSE content type (`text/event-stream`).
96+
If a subclass supports other types as well, it should consider `self.accept` in `encode_event` and return the resulting content type here.
97+
98+
Args:
99+
accept: The accept header value.
100+
"""
101+
return SSE_CONTENT_TYPE
102+
85103
@abstractmethod
86-
def encode_event(self, event: EventT, accept: str | None = None) -> str:
104+
def encode_event(self, event: EventT) -> str:
87105
"""Encode an event as a string.
88106
89107
Args:
90108
event: The event to encode.
91-
accept: The accept header value for encoding format.
92109
"""
93110
raise NotImplementedError
94111

95-
async def encode_stream(self, stream: AsyncIterator[EventT], accept: str | None = None) -> AsyncIterator[str]:
112+
async def encode_stream(self, stream: AsyncIterator[EventT]) -> AsyncIterator[str]:
96113
"""Encode a stream of events as SSE strings.
97114
98115
Args:
99116
stream: The stream of events to encode.
100-
accept: The accept header value for encoding format.
101117
"""
102118
async for event in stream:
103-
yield self.encode_event(event, accept)
119+
yield self.encode_event(event)
104120

105121
async def handle_stream( # noqa: C901
106122
self, stream: AsyncIterator[SourceEvent], on_complete: OnCompleteFunc[EventT] | None = None
@@ -147,19 +163,20 @@ async def handle_stream( # noqa: C901
147163
async for e in self.handle_function_tool_result(output_tool_result_event):
148164
yield e
149165

150-
self.result = cast(AgentRunResult[OutputDataT], event.result)
166+
result = cast(AgentRunResult[OutputDataT], event.result)
167+
self._result = result
151168

152169
async for e in self._turn_to(None):
153170
yield e
154171

155172
if on_complete is not None:
156173
if inspect.isasyncgenfunction(on_complete):
157-
async for e in on_complete(self.result):
174+
async for e in on_complete(result):
158175
yield e
159176
elif _utils.is_async_callable(on_complete):
160-
await on_complete(self.result)
177+
await on_complete(result)
161178
else:
162-
await _utils.run_in_executor(on_complete, self.result)
179+
await _utils.run_in_executor(on_complete, result)
163180
elif isinstance(event, FinalResultEvent):
164181
self._final_result_event = event
165182

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,10 @@ async def validate_request(cls, request: Request) -> RequestData:
7272
"""Validate a Vercel AI request."""
7373
return request_data_ta.validate_json(await request.body())
7474

75-
@property
76-
def event_stream(self) -> BaseEventStream[RequestData, BaseChunk, AgentDepsT, OutputDataT]:
77-
return VercelAIEventStream(self.request)
75+
def build_event_stream(
76+
self, accept: str | None = None
77+
) -> BaseEventStream[RequestData, BaseChunk, AgentDepsT, OutputDataT]:
78+
return VercelAIEventStream(self.request, accept=accept)
7879

7980
@property
8081
def response_headers(self) -> Mapping[str, str] | None:

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class VercelAIEventStream(BaseEventStream[RequestData, BaseChunk, AgentDepsT, Ou
6161

6262
_step_started: bool = False
6363

64-
def encode_event(self, event: BaseChunk, accept: str | None = None) -> str:
64+
def encode_event(self, event: BaseChunk) -> str:
6565
if isinstance(event, DoneChunk):
6666
return 'data: [DONE]\n\n'
6767
return f'data: {event.model_dump_json(by_alias=True, exclude_none=True)}\n\n'

tests/test_ui.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,8 @@ async def validate_request(cls, request: Request) -> UIRequest:
9292
def load_messages(cls, messages: Sequence[ModelMessage]) -> list[ModelMessage]:
9393
return list(messages)
9494

95-
@property
96-
def event_stream(self) -> UIEventStream[AgentDepsT, OutputDataT]:
97-
return UIEventStream[AgentDepsT, OutputDataT](self.request)
95+
def build_event_stream(self, accept: str | None = None) -> UIEventStream[AgentDepsT, OutputDataT]:
96+
return UIEventStream[AgentDepsT, OutputDataT](self.request, accept=accept)
9897

9998
@cached_property
10099
def messages(self) -> list[ModelMessage]:
@@ -115,7 +114,7 @@ def response_headers(self) -> dict[str, str]:
115114

116115
@dataclass(kw_only=True)
117116
class UIEventStream(BaseEventStream[UIRequest, str, AgentDepsT, OutputDataT]):
118-
def encode_event(self, event: str, accept: str | None = None) -> str:
117+
def encode_event(self, event: str) -> str:
119118
return event
120119

121120
async def handle_event(self, event: SourceEvent) -> AsyncIterator[str]:
@@ -629,7 +628,11 @@ async def send(data: MutableMapping[str, Any]) -> None:
629628

630629
assert chunks == snapshot(
631630
[
632-
{'type': 'http.response.start', 'status': 200, 'headers': [(b'x-test', b'test')]},
631+
{
632+
'type': 'http.response.start',
633+
'status': 200,
634+
'headers': [(b'x-test', b'test'), (b'content-type', b'text/event-stream; charset=utf-8')],
635+
},
633636
{'type': 'http.response.body', 'body': b'<stream>', 'more_body': True},
634637
{'type': 'http.response.body', 'body': b'<response>', 'more_body': True},
635638
{'type': 'http.response.body', 'body': b'<text follows_text=False>', 'more_body': True},

0 commit comments

Comments
 (0)