Skip to content

Commit 6e74a2a

Browse files
committed
add _sync methods to StreamedRunResult
1 parent 6497f63 commit 6e74a2a

File tree

3 files changed

+75
-95
lines changed

3 files changed

+75
-95
lines changed

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def run_stream_sync(
598598
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
599599
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
600600
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
601-
) -> Iterator[result.SyncStreamedRunResult[AgentDepsT, Any]]:
601+
) -> Iterator[result.StreamedRunResult[AgentDepsT, Any]]:
602602
"""Run the agent with a user prompt in sync streaming mode.
603603
604604
This method builds an internal agent graph (using system prompts, tools and output schemas) and then
@@ -659,8 +659,7 @@ def main():
659659
builtin_tools=builtin_tools,
660660
event_stream_handler=event_stream_handler,
661661
)
662-
async_result = get_event_loop().run_until_complete(async_cm.__aenter__())
663-
yield result.SyncStreamedRunResult.from_streamed_result(async_result) # type: ignore[reportReturnType]
662+
yield get_event_loop().run_until_complete(async_cm.__aenter__())
664663

665664
@overload
666665
def run_stream_events(

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 65 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,24 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
410410
else:
411411
raise ValueError('No stream response or run result provided') # pragma: no cover
412412

413+
def stream_output_sync(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
414+
"""Stream the output as an iterable.
415+
416+
The pydantic validator for structured data will be called in
417+
[partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
418+
on each iteration.
419+
420+
Args:
421+
debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
422+
Debouncing is particularly important for long structured outputs to reduce the overhead of
423+
performing validation as each token is received.
424+
425+
Returns:
426+
An iterable of the response data.
427+
"""
428+
async_stream = self.stream_output(debounce_by=debounce_by)
429+
yield from _lazy_async_iterator(async_stream)
430+
413431
async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
414432
"""Stream the text result as an async iterable.
415433
@@ -438,6 +456,22 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
438456
else:
439457
raise ValueError('No stream response or run result provided') # pragma: no cover
440458

459+
def stream_text_sync(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
460+
"""Stream the text result as an async iterable.
461+
462+
!!! note
463+
Result validators will NOT be called on the text result if `delta=True`.
464+
465+
Args:
466+
delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
467+
up to the current point.
468+
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
469+
Debouncing is particularly important for long structured responses to reduce the overhead of
470+
performing validation as each token is received.
471+
"""
472+
async_stream = self.stream_text(delta=delta, debounce_by=debounce_by)
473+
yield from _lazy_async_iterator(async_stream)
474+
441475
@deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
442476
async def stream_structured(
443477
self, *, debounce_by: float | None = 0.1
@@ -473,6 +507,22 @@ async def stream_responses(
473507
else:
474508
raise ValueError('No stream response or run result provided') # pragma: no cover
475509

510+
def stream_responses_sync(
511+
self, *, debounce_by: float | None = 0.1
512+
) -> Iterator[tuple[_messages.ModelResponse, bool]]:
513+
"""Stream the response as an iterable of Structured LLM Messages.
514+
515+
Args:
516+
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
517+
Debouncing is particularly important for long structured responses to reduce the overhead of
518+
performing validation as each token is received.
519+
520+
Returns:
521+
An iterable of the structured response message and whether that is the last message.
522+
"""
523+
async_stream = self.stream_responses(debounce_by=debounce_by)
524+
yield from _lazy_async_iterator(async_stream)
525+
476526
async def get_output(self) -> OutputDataT:
477527
"""Stream the whole response, validate and return it."""
478528
if self._run_result is not None:
@@ -486,6 +536,10 @@ async def get_output(self) -> OutputDataT:
486536
else:
487537
raise ValueError('No stream response or run result provided') # pragma: no cover
488538

539+
def get_output_sync(self) -> OutputDataT:
540+
"""Stream the whole response, validate and return it."""
541+
return get_event_loop().run_until_complete(self.get_output())
542+
489543
@property
490544
def response(self) -> _messages.ModelResponse:
491545
"""Return the current state of the response."""
@@ -545,90 +599,6 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
545599
await self._on_complete()
546600

547601

548-
@dataclass(init=False)
549-
class SyncStreamedRunResult(StreamedRunResult[AgentDepsT, OutputDataT]):
550-
"""Provides a synchronous API over 'StreamedRunResult'."""
551-
552-
@classmethod
553-
def from_streamed_result(
554-
cls, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
555-
) -> SyncStreamedRunResult[AgentDepsT, OutputDataT]:
556-
"""Create a 'SyncStreamedRunResult' from an existing 'StreamedRunResult'."""
557-
instance = cls.__new__(cls)
558-
559-
instance._all_messages = streamed_run_result._all_messages
560-
instance._new_message_index = streamed_run_result._new_message_index
561-
instance._stream_response = streamed_run_result._stream_response
562-
instance._on_complete = streamed_run_result._on_complete
563-
instance._run_result = streamed_run_result._run_result
564-
instance.is_complete = streamed_run_result.is_complete
565-
566-
return instance
567-
568-
def _lazy_async_iterator(self, async_iter: AsyncIterator[T]) -> Iterator[T]:
569-
"""Lazily yield items from async iterator as they're requested."""
570-
loop = get_event_loop()
571-
572-
while True:
573-
try:
574-
item = loop.run_until_complete(async_iter.__anext__())
575-
yield item
576-
except StopAsyncIteration:
577-
break
578-
579-
def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]: # type: ignore[reportIncompatibleMethodOverride]
580-
"""Stream the output as an iterable.
581-
582-
The pydantic validator for structured data will be called in
583-
[partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
584-
on each iteration.
585-
586-
Args:
587-
debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
588-
Debouncing is particularly important for long structured outputs to reduce the overhead of
589-
performing validation as each token is received.
590-
591-
Returns:
592-
An iterable of the response data.
593-
"""
594-
async_stream = super().stream_output(debounce_by=debounce_by)
595-
yield from self._lazy_async_iterator(async_stream)
596-
597-
def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]: # type: ignore[reportIncompatibleMethodOverride]
598-
"""Stream the text result as an iterable.
599-
600-
!!! note
601-
Result validators will NOT be called on the text result if `delta=True`.
602-
603-
Args:
604-
delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
605-
up to the current point.
606-
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
607-
Debouncing is particularly important for long structured responses to reduce the overhead of
608-
performing validation as each token is received.
609-
"""
610-
async_stream = super().stream_text(delta=delta, debounce_by=debounce_by)
611-
yield from self._lazy_async_iterator(async_stream)
612-
613-
def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]: # type: ignore[reportIncompatibleMethodOverride]
614-
"""Stream the response as an iterable of Structured LLM Messages.
615-
616-
Args:
617-
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
618-
Debouncing is particularly important for long structured responses to reduce the overhead of
619-
performing validation as each token is received.
620-
621-
Returns:
622-
An iterable of the structured response message and whether that is the last message.
623-
"""
624-
async_stream = super().stream_responses(debounce_by=debounce_by)
625-
yield from self._lazy_async_iterator(async_stream)
626-
627-
def get_output(self) -> OutputDataT: # type: ignore[reportIncompatibleMethodOverride]
628-
"""Stream the whole response, validate and return it."""
629-
return get_event_loop().run_until_complete(super().get_output())
630-
631-
632602
@dataclass(repr=False)
633603
class FinalResult(Generic[OutputDataT]):
634604
"""Marker class storing the final output of an agent run and associated metadata."""
@@ -645,6 +615,17 @@ class FinalResult(Generic[OutputDataT]):
645615
__repr__ = _utils.dataclasses_no_defaults_repr
646616

647617

618+
def _lazy_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
619+
loop = get_event_loop()
620+
621+
while True:
622+
try:
623+
item = loop.run_until_complete(async_iter.__anext__())
624+
yield item
625+
except StopAsyncIteration:
626+
break
627+
628+
648629
def _get_usage_checking_stream_response(
649630
stream_response: models.StreamedResponse,
650631
limits: UsageLimits | None,

tests/test_streaming.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ async def ret_a(x: str) -> str:
174174
tool_calls=1,
175175
)
176176
)
177-
response = result.get_output()
177+
response = result.get_output_sync()
178178
assert response == snapshot('{"ret_a":"a-apple"}')
179179
assert result.is_complete
180180
assert result.timestamp() == IsNow(tz=timezone.utc)
@@ -389,20 +389,20 @@ def test_streamed_text_stream_sync():
389389

390390
with agent.run_stream_sync('Hello') as result:
391391
# typehint to test (via static typing) that the stream type is correctly inferred
392-
chunks: list[str] = [c for c in result.stream_text()]
392+
chunks: list[str] = [c for c in result.stream_text_sync()]
393393
# one chunk with `stream_text()` due to group_by_temporal
394394
assert chunks == snapshot(['The cat sat on the mat.'])
395395
assert result.is_complete
396396

397397
with agent.run_stream_sync('Hello') as result:
398398
# typehint to test (via static typing) that the stream type is correctly inferred
399-
chunks: list[str] = [c for c in result.stream_output()]
399+
chunks: list[str] = [c for c in result.stream_output_sync()]
400400
# two chunks with `stream()` due to not-final vs. final
401401
assert chunks == snapshot(['The cat sat on the mat.', 'The cat sat on the mat.'])
402402
assert result.is_complete
403403

404404
with agent.run_stream_sync('Hello') as result:
405-
assert [c for c in result.stream_text(debounce_by=None)] == snapshot(
405+
assert [c for c in result.stream_text_sync(debounce_by=None)] == snapshot(
406406
[
407407
'The ',
408408
'The cat ',
@@ -415,20 +415,20 @@ def test_streamed_text_stream_sync():
415415

416416
with agent.run_stream_sync('Hello') as result:
417417
# with stream_text, there is no need to do partial validation, so we only get the final message once:
418-
assert [c for c in result.stream_text(delta=False, debounce_by=None)] == snapshot(
418+
assert [c for c in result.stream_text_sync(delta=False, debounce_by=None)] == snapshot(
419419
['The ', 'The cat ', 'The cat sat ', 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.']
420420
)
421421

422422
with agent.run_stream_sync('Hello') as result:
423-
assert [c for c in result.stream_text(delta=True, debounce_by=None)] == snapshot(
423+
assert [c for c in result.stream_text_sync(delta=True, debounce_by=None)] == snapshot(
424424
['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.']
425425
)
426426

427427
def upcase(text: str) -> str:
428428
return text.upper()
429429

430430
with agent.run_stream_sync('Hello', output_type=TextOutput(upcase)) as result:
431-
assert [c for c in result.stream_output(debounce_by=None)] == snapshot(
431+
assert [c for c in result.stream_output_sync(debounce_by=None)] == snapshot(
432432
[
433433
'THE ',
434434
'THE CAT ',
@@ -441,7 +441,7 @@ def upcase(text: str) -> str:
441441
)
442442

443443
with agent.run_stream_sync('Hello') as result:
444-
assert [c for c, _is_last in result.stream_responses(debounce_by=None)] == snapshot(
444+
assert [c for c, _is_last in result.stream_responses_sync(debounce_by=None)] == snapshot(
445445
[
446446
ModelResponse(
447447
parts=[TextPart(content='The ')],

0 commit comments

Comments
 (0)