Skip to content

Commit cf2b9b6

Browse files
committed
Tweaks, add validate_response_output_sync
1 parent 584ae6a commit cf2b9b6

File tree

4 files changed

+49
-21
lines changed

4 files changed

+49
-21
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,15 @@ def sync_anext(iterator: Iterator[T]) -> T:
234234
raise StopAsyncIteration() from e
235235

236236

237+
def sync_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
238+
loop = get_event_loop()
239+
while True:
240+
try:
241+
yield loop.run_until_complete(anext(async_iter))
242+
except StopAsyncIteration:
243+
break
244+
245+
237246
def now_utc() -> datetime:
238247
return datetime.now(tz=timezone.utc)
239248

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing_extensions import Self, TypeIs, TypeVar
1313

1414
from pydantic_graph import End
15-
from pydantic_graph._utils import get_event_loop
1615

1716
from .. import (
1817
_agent_graph,
@@ -335,7 +334,7 @@ def run_sync(
335334
if infer_name and self.name is None:
336335
self._infer_name(inspect.currentframe())
337336

338-
return get_event_loop().run_until_complete(
337+
return _utils.get_event_loop().run_until_complete(
339338
self.run(
340339
user_prompt,
341340
output_type=output_type,
@@ -685,8 +684,7 @@ def main():
685684
The result of the run.
686685
"""
687686
if infer_name and self.name is None:
688-
if frame := inspect.currentframe(): # pragma: no branch
689-
self._infer_name(frame)
687+
self._infer_name(inspect.currentframe())
690688

691689
async def _consume_stream():
692690
async with self.run_stream(
@@ -706,7 +704,7 @@ async def _consume_stream():
706704
) as stream_result:
707705
yield stream_result
708706

709-
return get_event_loop().run_until_complete(_consume_stream().__anext__())
707+
return _utils.get_event_loop().run_until_complete(anext(_consume_stream()))
710708

711709
@overload
712710
def run_stream_events(
@@ -1344,6 +1342,6 @@ def to_cli_sync(
13441342
agent.to_cli_sync(prog_name='assistant')
13451343
```
13461344
"""
1347-
return get_event_loop().run_until_complete(
1345+
return _utils.get_event_loop().run_until_complete(
13481346
self.to_cli(deps=deps, prog_name=prog_name, message_history=message_history)
13491347
)

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,7 @@ def stream_output_sync(self, *, debounce_by: float | None = 0.1) -> Iterator[Out
438438
Returns:
439439
An iterable of the response data.
440440
"""
441-
async_stream = self.stream_output(debounce_by=debounce_by)
442-
yield from _blocking_async_iterator(async_stream)
441+
return _utils.sync_async_iterator(self.stream_output(debounce_by=debounce_by))
443442

444443
async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
445444
"""Stream the text result as an async iterable.
@@ -485,8 +484,7 @@ def stream_text_sync(self, *, delta: bool = False, debounce_by: float | None = 0
485484
Debouncing is particularly important for long structured responses to reduce the overhead of
486485
performing validation as each token is received.
487486
"""
488-
async_stream = self.stream_text(delta=delta, debounce_by=debounce_by)
489-
yield from _blocking_async_iterator(async_stream)
487+
return _utils.sync_async_iterator(self.stream_text(delta=delta, debounce_by=debounce_by))
490488

491489
@deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
492490
async def stream_structured(
@@ -539,8 +537,7 @@ def stream_responses_sync(
539537
Returns:
540538
An iterable of the structured response message and whether that is the last message.
541539
"""
542-
async_stream = self.stream_responses(debounce_by=debounce_by)
543-
yield from _blocking_async_iterator(async_stream)
540+
return _utils.sync_async_iterator(self.stream_responses(debounce_by=debounce_by))
544541

545542
async def get_output(self) -> OutputDataT:
546543
"""Stream the whole response, validate and return it."""
@@ -614,6 +611,18 @@ async def validate_response_output(
614611
else:
615612
raise ValueError('No stream response or run result provided') # pragma: no cover
616613

614+
def validate_response_output_sync(
615+
self, message: _messages.ModelResponse, *, allow_partial: bool = False
616+
) -> OutputDataT:
617+
"""Validate a structured result message.
618+
619+
This is a convenience method that wraps [`validate_response_output()`][pydantic_ai.result.StreamedRunResult.validate_response_output] with `loop.run_until_complete(...)`.
620+
You therefore can't use this method inside async code or if there's an active event loop.
621+
"""
622+
return _utils.get_event_loop().run_until_complete(
623+
self.validate_response_output(message, allow_partial=allow_partial)
624+
)
625+
617626
async def _marked_completed(self, message: _messages.ModelResponse | None = None) -> None:
618627
self.is_complete = True
619628
if message is not None:
@@ -638,15 +647,6 @@ class FinalResult(Generic[OutputDataT]):
638647
__repr__ = _utils.dataclasses_no_defaults_repr
639648

640649

641-
def _blocking_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
642-
loop = _utils.get_event_loop()
643-
while True:
644-
try:
645-
yield loop.run_until_complete(async_iter.__anext__())
646-
except StopAsyncIteration:
647-
break
648-
649-
650650
def _get_usage_checking_stream_response(
651651
stream_response: models.StreamedResponse,
652652
limits: UsageLimits | None,

tests/test_streaming.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,3 +2085,24 @@ async def ret_a(x: str) -> str:
20852085
AgentRunResultEvent(result=AgentRunResult(output='{"ret_a":"a-apple"}')),
20862086
]
20872087
)
2088+
2089+
2090+
def test_structured_response_sync_validation():
2091+
async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
2092+
assert agent_info.output_tools is not None
2093+
assert len(agent_info.output_tools) == 1
2094+
name = agent_info.output_tools[0].name
2095+
json_data = json.dumps({'response': [1, 2, 3, 4]})
2096+
yield {0: DeltaToolCall(name=name)}
2097+
yield {0: DeltaToolCall(json_args=json_data[:15])}
2098+
yield {0: DeltaToolCall(json_args=json_data[15:])}
2099+
2100+
agent = Agent(FunctionModel(stream_function=text_stream), output_type=list[int])
2101+
2102+
chunks: list[list[int]] = []
2103+
result = agent.run_stream_sync('')
2104+
for structured_response, last in result.stream_responses_sync(debounce_by=None):
2105+
response_data = result.validate_response_output_sync(structured_response, allow_partial=not last)
2106+
chunks.append(response_data)
2107+
2108+
assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])

0 commit comments

Comments
 (0)