diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 5f89155092..50eb697b34 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -915,8 +915,10 @@ async def handle_call_or_result( yield event else: + # TODO(Marcelo): We need to replace this with `anyio.as_completed()` from + # https://github.com/agronholm/anyio/pull/890. tasks = [ - asyncio.create_task( + asyncio.create_task( # noqa: TID251 _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits), name=call.tool_name, ) diff --git a/pydantic_ai_slim/pydantic_ai/_anyio.py b/pydantic_ai_slim/pydantic_ai/_anyio.py new file mode 100644 index 0000000000..833313549f --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/_anyio.py @@ -0,0 +1,23 @@ +from __future__ import annotations as _annotations + +from collections.abc import Awaitable +from typing import TypeVar, cast + +import anyio + +T = TypeVar('T') + + +async def gather(*awaitables: Awaitable[T]) -> list[T]: + """Run multiple awaitables concurrently using an AnyIO task group.""" + # We initialize the list, so we can insert the results in the correct order. + results: list[T] = cast(list[T], [None] * len(awaitables)) + + async def run_and_store(coro: Awaitable[T], index: int) -> None: + results[index] = await coro + + async with anyio.create_task_group() as tg: + for i, c in enumerate(awaitables): + tg.start_soon(run_and_store, c, i) + + return results diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index bc3f2a271b..fce98972b8 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -1,17 +1,16 @@ from __future__ import annotations as _annotations import argparse -import asyncio import importlib import os import sys -from asyncio import CancelledError from collections.abc import Sequence from contextlib import ExitStack from datetime import datetime, timezone from pathlib import Path from typing import Any, cast +import anyio from typing_inspection.introspection import get_literal_values from . import __version__ @@ -209,13 +208,13 @@ def cli( # noqa: C901 if prompt := cast(str, args.prompt): try: - asyncio.run(ask_agent(agent, prompt, stream, console, code_theme)) + anyio.run(ask_agent, agent, prompt, stream, console, code_theme) except KeyboardInterrupt: pass return 0 try: - return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name)) + return anyio.run(run_chat, stream, agent, console, code_theme, prog_name) except KeyboardInterrupt: # pragma: no cover return 0 @@ -256,7 +255,7 @@ async def run_chat( else: try: messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages) - except CancelledError: # pragma: no cover + except anyio.get_cancelled_exc_class(): # pragma: no cover console.print('[dim]Interrupted[/dim]') except Exception as e: # pragma: no cover cause = getattr(e, '__cause__', None) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 024a568f2e..1eda15c356 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -14,14 +14,11 @@ from types import GenericAlias from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload +import anyio from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter from pydantic.json_schema import JsonSchemaValue -from typing_extensions import ( - ParamSpec, - TypeIs, - is_typeddict, -) +from typing_extensions import ParamSpec, TypeIs, is_typeddict from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin @@ -97,8 +94,6 @@ class Some(Generic[T]): class Unset: """A singleton to represent an unset value.""" - pass - UNSET = Unset() @@ -197,7 +192,8 @@ async def async_iter_groups() -> AsyncIterator[list[T]]: group_start_time = None try: - yield async_iter_groups() + async with anyio.create_task_group() as tg: + yield async_iter_groups() finally: # pragma: no cover # after iteration if a tasks still exists, cancel it, this will only happen if an error occurred if task: diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index b70f541262..21fdb93501 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -1,15 +1,16 @@ from __future__ import annotations as _annotations import dataclasses +import functools import inspect import json import warnings -from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from typing import TYPE_CHECKING, Any, ClassVar, cast, overload +import anyio from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Self, TypeVar, deprecated @@ -153,10 +154,14 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) - _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + @functools.cached_property + def _enter_lock(self) -> anyio.Lock: + # We use a cached_property for this because it seems to work better with temporal... + return anyio.Lock() + @overload def __init__( self, @@ -371,7 +376,6 @@ def __init__( _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] ] = ContextVar('_override_tools', default=None) - self._enter_lock = Lock() self._entered_count = 0 self._exit_stack = None diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index b61f254500..a13a7d21ed 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -4,10 +4,9 @@ import functools import warnings from abc import ABC, abstractmethod -from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from dataclasses import field, replace +from dataclasses import replace from datetime import timedelta from pathlib import Path from typing import Annotated, Any @@ -105,7 +104,6 @@ class MCPServer(AbstractToolset[Any], ABC): _id: str | None - _enter_lock: Lock = field(compare=False) _running_count: int _exit_stack: AsyncExitStack | None @@ -113,6 +111,10 @@ class MCPServer(AbstractToolset[Any], ABC): _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] + @functools.cached_property + def _enter_lock(self) -> anyio.Lock: + return anyio.Lock() + def __init__( self, tool_prefix: str | None = None, @@ -144,7 +146,6 @@ def __init__( self.__post_init__() def __post_init__(self): - self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None diff --git a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py index 53a0cb48f7..bb86068396 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py @@ -1,11 +1,11 @@ from __future__ import annotations as _annotations import functools -from asyncio import Lock from collections.abc import AsyncGenerator, Mapping from pathlib import Path from typing import Literal, overload +import anyio import anyio.to_thread import httpx from typing_extensions import deprecated @@ -119,10 +119,13 @@ def __init__( class _VertexAIAuth(httpx.Auth): """Auth class for Vertex AI API.""" - _refresh_lock: Lock = Lock() - credentials: BaseCredentials | ServiceAccountCredentials | None + @functools.cached_property + def _refresh_lock(self) -> anyio.Lock: + # We use a cached_property for this because it seems to work better with temporal... + return anyio.Lock() + def __init__( self, service_account_file: Path | str | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index e095e4aa1f..c403632aa9 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -1,14 +1,15 @@ from __future__ import annotations -import asyncio -from asyncio import Lock +import functools from collections.abc import Callable, Sequence from contextlib import AsyncExitStack from dataclasses import dataclass, field, replace from typing import Any +import anyio from typing_extensions import Self +from .. import _anyio from .._run_context import AgentDepsT, RunContext from ..exceptions import UserError from .abstract import AbstractToolset, ToolsetTool @@ -31,10 +32,14 @@ class CombinedToolset(AbstractToolset[AgentDepsT]): toolsets: Sequence[AbstractToolset[AgentDepsT]] - _enter_lock: Lock = field(compare=False, init=False, default_factory=Lock) _entered_count: int = field(init=False, default=0) _exit_stack: AsyncExitStack | None = field(init=False, default=None) + @functools.cached_property + def _enter_lock(self) -> anyio.Lock: + # We use a cached_property for this because it seems to work better with temporal... + return anyio.Lock() + @property def id(self) -> str | None: return None # pragma: no cover @@ -61,7 +66,7 @@ async def __aexit__(self, *args: Any) -> bool | None: self._exit_stack = None async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: - toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets)) + toolsets_tools = await _anyio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets)) all_tools: dict[str, ToolsetTool[AgentDepsT]] = {} for toolset, tools in zip(self.toolsets, toolsets_tools): diff --git a/pyproject.toml b/pyproject.toml index 204cee43ae..a6cffd9204 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,8 @@ convention = "google" [tool.ruff.lint.flake8-tidy-imports.banned-api] "typing.TypedDict".msg = "Use typing_extensions.TypedDict instead." +"asyncio.gather".msg = "Use pydantic_ai._anyio.run_all instead." +"asyncio.create_task".msg = "Use `anyio.create_task_group` instead." [tool.ruff.format] # don't format python in docstrings, pytest-examples takes care of it