Skip to content

Commit 963ac22

Browse files
committed
Almost there
1 parent 3b94453 commit 963ac22

File tree

6 files changed

+38
-16
lines changed

6 files changed

+38
-16
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,8 +915,10 @@ async def handle_call_or_result(
915915
yield event
916916

917917
else:
918+
# TODO(Marcelo): We need to replace this with `anyio.as_completed()` from
919+
# https://github.com/agronholm/anyio/pull/890.
918920
tasks = [
919-
asyncio.create_task(
921+
asyncio.create_task( # noqa: TID251
920922
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
921923
name=call.tool_name,
922924
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import annotations as _annotations
2+
3+
from collections.abc import Awaitable
4+
from typing import TypeVar, cast
5+
6+
import anyio
7+
8+
T = TypeVar('T')
9+
10+
11+
async def gather(*awaitables: Awaitable[T]) -> list[T]:
12+
"""Run multiple awaitables concurrently using an AnyIO task group."""
13+
# We initialize the list, so we can insert the results in the correct order.
14+
results: list[T] = cast(list[T], [None] * len(awaitables))
15+
16+
async def run_and_store(coro: Awaitable[T], index: int) -> None:
17+
results[index] = await coro
18+
19+
async with anyio.create_task_group() as tg:
20+
for i, c in enumerate(awaitables):
21+
tg.start_soon(run_and_store, c, i)
22+
23+
return results

pydantic_ai_slim/pydantic_ai/_cli.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from __future__ import annotations as _annotations
22

33
import argparse
4-
import asyncio
54
import importlib
65
import os
76
import sys
8-
from asyncio import CancelledError
97
from collections.abc import Sequence
108
from contextlib import ExitStack
119
from datetime import datetime, timezone
1210
from pathlib import Path
1311
from typing import Any, cast
1412

13+
import anyio
1514
from typing_inspection.introspection import get_literal_values
1615

1716
from . import __version__
@@ -209,13 +208,13 @@ def cli( # noqa: C901
209208

210209
if prompt := cast(str, args.prompt):
211210
try:
212-
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
211+
anyio.run(ask_agent, agent, prompt, stream, console, code_theme)
213212
except KeyboardInterrupt:
214213
pass
215214
return 0
216215

217216
try:
218-
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
217+
return anyio.run(run_chat, stream, agent, console, code_theme, prog_name)
219218
except KeyboardInterrupt: # pragma: no cover
220219
return 0
221220

@@ -256,7 +255,7 @@ async def run_chat(
256255
else:
257256
try:
258257
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
259-
except CancelledError: # pragma: no cover
258+
except anyio.get_cancelled_exc_class(): # pragma: no cover
260259
console.print('[dim]Interrupted[/dim]')
261260
except Exception as e: # pragma: no cover
262261
cause = getattr(e, '__cause__', None)

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@
1414
from types import GenericAlias
1515
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
1616

17+
import anyio
1718
from anyio.to_thread import run_sync
1819
from pydantic import BaseModel, TypeAdapter
1920
from pydantic.json_schema import JsonSchemaValue
20-
from typing_extensions import (
21-
ParamSpec,
22-
TypeIs,
23-
is_typeddict,
24-
)
21+
from typing_extensions import ParamSpec, TypeIs, is_typeddict
2522
from typing_inspection import typing_objects
2623
from typing_inspection.introspection import is_union_origin
2724

@@ -97,8 +94,6 @@ class Some(Generic[T]):
9794
class Unset:
9895
"""A singleton to represent an unset value."""
9996

100-
pass
101-
10297

10398
UNSET = Unset()
10499

@@ -197,7 +192,8 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
197192
group_start_time = None
198193

199194
try:
200-
yield async_iter_groups()
195+
async with anyio.create_task_group() as tg:
196+
yield async_iter_groups()
201197
finally: # pragma: no cover
202198
# after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
203199
if task:

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import functools
54
from collections.abc import Callable, Sequence
65
from contextlib import AsyncExitStack
@@ -10,6 +9,7 @@
109
import anyio
1110
from typing_extensions import Self
1211

12+
from .. import _anyio
1313
from .._run_context import AgentDepsT, RunContext
1414
from ..exceptions import UserError
1515
from .abstract import AbstractToolset, ToolsetTool
@@ -66,7 +66,7 @@ async def __aexit__(self, *args: Any) -> bool | None:
6666
self._exit_stack = None
6767

6868
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
69-
toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
69+
toolsets_tools = await _anyio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
7070
all_tools: dict[str, ToolsetTool[AgentDepsT]] = {}
7171

7272
for toolset, tools in zip(self.toolsets, toolsets_tools):

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ convention = "google"
170170

171171
[tool.ruff.lint.flake8-tidy-imports.banned-api]
172172
"typing.TypedDict".msg = "Use typing_extensions.TypedDict instead."
173+
"asyncio.gather".msg = "Use pydantic_ai._anyio.run_all instead."
174+
"asyncio.create_task".msg = "Use `anyio.create_task_group` instead."
173175

174176
[tool.ruff.format]
175177
# don't format python in docstrings, pytest-examples takes care of it

0 commit comments

Comments
 (0)