Skip to content

Commit 0391387

Browse files
committed
Refactor to use anyio
1 parent be30528 commit 0391387

File tree

6 files changed

+253
-170
lines changed

6 files changed

+253
-170
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ async def group_by_temporal(
147147
aiterable: The async iterable to group.
148148
soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
149149
a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
150-
as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed
150+
as soon as `anext(aiter)` returns. If `None`, no grouping/debouncing is performed
151151
152152
Returns:
153153
A context manager usable as an async iterable of lists of items produced by the input async iterable.
@@ -171,7 +171,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
171171
buffer: list[T] = []
172172
group_start_time = time.monotonic()
173173

174-
aiterator = aiterable.__aiter__()
174+
aiterator = aiter(aiterable)
175175
while True:
176176
if group_start_time is None:
177177
# group hasn't started, we just wait for the maximum interval
@@ -182,9 +182,9 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
182182

183183
# if there's no current task, we get the next one
184184
if task is None:
185-
# aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
185+
# anext(aiter) returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
186186
# so far, this doesn't seem to be a problem
187-
task = asyncio.create_task(aiterator.__anext__()) # pyright: ignore[reportArgumentType]
187+
task = asyncio.create_task(anext(aiterator)) # pyright: ignore[reportArgumentType]
188188

189189
# we use asyncio.wait to avoid cancelling the coroutine if it's not done
190190
done, _ = await asyncio.wait((task,), timeout=wait_time)
@@ -284,10 +284,10 @@ async def peek(self) -> T | Unset:
284284

285285
# Otherwise, we need to fetch the next item from the underlying iterator.
286286
if self._source_iter is None:
287-
self._source_iter = self._source.__aiter__()
287+
self._source_iter = aiter(self._source)
288288

289289
try:
290-
self._buffer = await self._source_iter.__anext__()
290+
self._buffer = await anext(self._source_iter)
291291
except StopAsyncIteration:
292292
self._exhausted = True
293293
return UNSET
@@ -318,10 +318,10 @@ async def __anext__(self) -> T:
318318

319319
# Otherwise, fetch the next item from the source.
320320
if self._source_iter is None:
321-
self._source_iter = self._source.__aiter__()
321+
self._source_iter = aiter(self._source)
322322

323323
try:
324-
return await self._source_iter.__anext__()
324+
return await anext(self._source_iter)
325325
except StopAsyncIteration:
326326
self._exhausted = True
327327
raise

pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ async def __anext__(
145145
self,
146146
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
147147
"""Advance to the next node automatically based on the last returned node."""
148-
task = await self._graph_run.__anext__()
148+
task = await anext(self._graph_run)
149149
return self._task_to_node(task)
150150

151151
def _task_to_node(

pydantic_graph/pydantic_graph/beta/decision.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
import inspect
11-
from collections.abc import Callable, Iterable, Sequence
11+
from collections.abc import AsyncIterable, Callable, Iterable, Sequence
1212
from dataclasses import dataclass
1313
from typing import TYPE_CHECKING, Any, Generic, get_origin
1414

@@ -232,7 +232,8 @@ def transform(
232232
)
233233

234234
def map(
235-
self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT],
235+
self: DecisionBranchBuilder[StateT, DepsT, Iterable[T], SourceT, HandledT]
236+
| DecisionBranchBuilder[StateT, DepsT, AsyncIterable[T], SourceT, HandledT],
236237
*,
237238
fork_id: str | None = None,
238239
downstream_join_id: str | None = None,

0 commit comments

Comments
 (0)