Skip to content

Commit 8623cb9

Browse files
committed
add lazy implementation
1 parent 64586b4 commit 8623cb9

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,12 +586,13 @@ def run_stream_sync(
586586
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
587587
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
588588
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
589-
) -> Iterator[result.CollectedRunResult[AgentDepsT, Any]]:
590-
"""Run the agent with a user prompt in collected streaming mode.
589+
) -> Iterator[result.SyncStreamedRunResult[AgentDepsT, Any]]:
590+
"""Run the agent with a user prompt in sync streaming mode.
591591
592592
This method builds an internal agent graph (using system prompts, tools and output schemas) and then
593593
runs the graph until the model produces output matching the `output_type`, for example text or structured data.
594-
At this point, a streaming run result object is collected and -- once this output has completed streaming -- you can iterate over the complete output, message history, and usage.
594+
At this point, a streaming run result object is yielded from which you can stream the output as it comes in,
595+
and -- once this output has completed streaming -- get the complete output, message history, and usage.
595596
596597
As this method will consider the first output matching the `output_type` to be the final output,
597598
it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.
@@ -647,7 +648,7 @@ def main():
647648
event_stream_handler=event_stream_handler,
648649
)
649650
async_result = get_event_loop().run_until_complete(async_cm.__aenter__())
650-
yield result.CollectedRunResult.from_streamed_result(async_result) # type: ignore[reportReturnType]
651+
yield result.SyncStreamedRunResult.from_streamed_result(async_result) # type: ignore[reportReturnType]
651652

652653
@overload
653654
def run_stream_events(

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -546,14 +546,14 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
546546

547547

548548
@dataclass(init=False)
549-
class CollectedRunResult(StreamedRunResult[AgentDepsT, OutputDataT]):
550-
"""Provides a synchronous API over 'StreamedRunResult' by eagerly loading the stream."""
549+
class SyncStreamedRunResult(StreamedRunResult[AgentDepsT, OutputDataT]):
550+
"""Provides a synchronous API over 'StreamedRunResult'."""
551551

552552
@classmethod
553553
def from_streamed_result(
554554
cls, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
555-
) -> CollectedRunResult[AgentDepsT, OutputDataT]:
556-
"""Create a CollectedRunResult from an existing StreamedRunResult."""
555+
) -> SyncStreamedRunResult[AgentDepsT, OutputDataT]:
556+
"""Create a 'SyncStreamedRunResult' from an existing 'StreamedRunResult'."""
557557
instance = cls.__new__(cls)
558558

559559
instance._all_messages = streamed_run_result._all_messages
@@ -565,14 +565,19 @@ def from_streamed_result(
565565

566566
return instance
567567

568-
def _collect_async_iterator(self, async_iter: AsyncIterator[T]) -> list[T]:
569-
async def collect():
570-
return [item async for item in async_iter]
568+
def _lazy_async_iterator(self, async_iter: AsyncIterator[T]) -> Iterator[T]:
569+
"""Lazily yield items from async iterator as they're requested."""
570+
loop = get_event_loop()
571571

572-
return get_event_loop().run_until_complete(collect())
572+
while True:
573+
try:
574+
item = loop.run_until_complete(async_iter.__anext__())
575+
yield item
576+
except StopAsyncIteration:
577+
break
573578

574579
def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]: # type: ignore[reportIncompatibleMethodOverride]
575-
"""Collect and stream the output as an iterable.
580+
"""Stream the output as an iterable.
576581
577582
The pydantic validator for structured data will be called in
578583
[partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
@@ -587,10 +592,10 @@ def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDa
587592
An iterable of the response data.
588593
"""
589594
async_stream = super().stream_output(debounce_by=debounce_by)
590-
yield from self._collect_async_iterator(async_stream)
595+
yield from self._lazy_async_iterator(async_stream)
591596

592597
def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]: # type: ignore[reportIncompatibleMethodOverride]
593-
"""Collect and stream the text result as an iterable.
598+
"""Stream the text result as an iterable.
594599
595600
!!! note
596601
Result validators will NOT be called on the text result if `delta=True`.
@@ -603,10 +608,10 @@ def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -
603608
performing validation as each token is received.
604609
"""
605610
async_stream = super().stream_text(delta=delta, debounce_by=debounce_by)
606-
yield from self._collect_async_iterator(async_stream)
611+
yield from self._lazy_async_iterator(async_stream)
607612

608613
def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]: # type: ignore[reportIncompatibleMethodOverride]
609-
"""Collect and stream the response as an iterable of Structured LLM Messages.
614+
"""Stream the response as an iterable of Structured LLM Messages.
610615
611616
Args:
612617
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
@@ -617,7 +622,7 @@ def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple
617622
An iterable of the structured response message and whether that is the last message.
618623
"""
619624
async_stream = super().stream_responses(debounce_by=debounce_by)
620-
yield from self._collect_async_iterator(async_stream)
625+
yield from self._lazy_async_iterator(async_stream)
621626

622627
def get_output(self) -> OutputDataT: # type: ignore[reportIncompatibleMethodOverride]
623628
"""Stream the whole response, validate and return it."""

0 commit comments

Comments
 (0)