pydantic_ai_slim/pydantic_ai/_agent_graph.py (4 changes: 3 additions & 1 deletion)

@@ -915,8 +915,10 @@ async def handle_call_or_result(
yield event

else:
# TODO(Marcelo): We need to replace this with `anyio.as_completed()` from
# https://github.com/agronholm/anyio/pull/890.
Collaborator comment:
You intend to vendor that function right? Because I don't think we should wait until that's released and then bump the requirement all the way to the latest -- that'd be annoying for users

tasks = [
asyncio.create_task(
asyncio.create_task( # noqa: TID251
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id), usage_limits),
name=call.tool_name,
)
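For context on the TODO and the reviewer's question above, one possible shape for a vendored as_completed-style helper is sketched below. This is purely illustrative: the helper in agronholm/anyio#890 may use a different name, signature, and semantics, and yielding from inside a task group needs care if the consumer stops iterating early.

from collections.abc import AsyncIterator, Awaitable
from typing import TypeVar

import anyio

T = TypeVar('T')


async def as_completed(*awaitables: Awaitable[T]) -> AsyncIterator[T]:
    """Yield each awaitable's result as soon as it finishes (illustrative sketch only)."""
    send, receive = anyio.create_memory_object_stream(max_buffer_size=len(awaitables))

    async def run(aw: Awaitable[T]) -> None:
        await send.send(await aw)

    async with anyio.create_task_group() as tg:
        for aw in awaitables:
            tg.start_soon(run, aw)
        # Results are yielded in completion order, not argument order.
        for _ in range(len(awaitables)):
            yield await receive.receive()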
pydantic_ai_slim/pydantic_ai/_anyio.py (new file, 23 additions)
@@ -0,0 +1,23 @@
from __future__ import annotations as _annotations

from collections.abc import Awaitable
from typing import TypeVar, cast

import anyio

T = TypeVar('T')


async def gather(*awaitables: Awaitable[T]) -> list[T]:
"""Run multiple awaitables concurrently using an AnyIO task group."""
# We initialize the list, so we can insert the results in the correct order.
results: list[T] = cast(list[T], [None] * len(awaitables))

async def run_and_store(coro: Awaitable[T], index: int) -> None:
results[index] = await coro

async with anyio.create_task_group() as tg:
for i, c in enumerate(awaitables):
tg.start_soon(run_and_store, c, i)

return results
Collaborator comment:
Can we return a tuple?

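If the maintainers do want the tuple return the reviewer asks about, a minimal variant of the helper above could look like this. This is a sketch only; the name gather_tuple is illustrative and not part of the PR.

from collections.abc import Awaitable
from typing import TypeVar, cast

import anyio

T = TypeVar('T')


async def gather_tuple(*awaitables: Awaitable[T]) -> tuple[T, ...]:
    """Like the gather() helper above, but returns an immutable tuple in argument order."""
    results: list[T] = cast(list[T], [None] * len(awaitables))

    async def run_and_store(coro: Awaitable[T], index: int) -> None:
        results[index] = await coro

    async with anyio.create_task_group() as tg:
        for i, c in enumerate(awaitables):
            tg.start_soon(run_and_store, c, i)

    return tuple(results)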
pydantic_ai_slim/pydantic_ai/_cli.py (9 changes: 4 additions & 5 deletions)
@@ -1,17 +1,16 @@
from __future__ import annotations as _annotations

import argparse
import asyncio
import importlib
import os
import sys
from asyncio import CancelledError
from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, cast

import anyio
from typing_inspection.introspection import get_literal_values

from . import __version__
@@ -209,13 +208,13 @@ def cli( # noqa: C901

if prompt := cast(str, args.prompt):
try:
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
anyio.run(ask_agent, agent, prompt, stream, console, code_theme)
except KeyboardInterrupt:
pass
return 0

try:
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
return anyio.run(run_chat, stream, agent, console, code_theme, prog_name)
except KeyboardInterrupt: # pragma: no cover
return 0

@@ -256,7 +255,7 @@ async def run_chat(
else:
try:
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
except CancelledError: # pragma: no cover
except anyio.get_cancelled_exc_class(): # pragma: no cover
console.print('[dim]Interrupted[/dim]')
except Exception as e: # pragma: no cover
cause = getattr(e, '__cause__', None)
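The calling convention changes in this file because anyio.run takes the async function plus its positional arguments and creates the coroutine itself, while asyncio.run expects an already-created coroutine object. A minimal standalone illustration (the greet function is made up for this sketch):

import asyncio

import anyio


async def greet(name: str) -> str:
    return f'hello {name}'


print(asyncio.run(greet('world')))  # asyncio.run(coroutine_object)
print(anyio.run(greet, 'world'))    # anyio.run(async_function, *args)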
pydantic_ai_slim/pydantic_ai/_utils.py (12 changes: 4 additions & 8 deletions)

@@ -14,14 +14,11 @@
from types import GenericAlias
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload

import anyio
from anyio.to_thread import run_sync
from pydantic import BaseModel, TypeAdapter
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import (
ParamSpec,
TypeIs,
is_typeddict,
)
from typing_extensions import ParamSpec, TypeIs, is_typeddict
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin

@@ -97,8 +94,6 @@ class Some(Generic[T]):
class Unset:
"""A singleton to represent an unset value."""

pass


UNSET = Unset()

@@ -197,7 +192,8 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
group_start_time = None

try:
yield async_iter_groups()
async with anyio.create_task_group() as tg:
yield async_iter_groups()
finally: # pragma: no cover
# after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
if task:
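The hunk above wraps the yield in an anyio task group, presumably so that any background work started inside is bounded by the generator's lifetime rather than living as a free-floating asyncio task. A simplified sketch of that pattern, not the actual group_by_temporal implementation:

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import anyio


@asynccontextmanager
async def with_background_task() -> AsyncIterator[anyio.Event]:
    done = anyio.Event()

    async def worker() -> None:
        await anyio.sleep(0.01)
        done.set()

    async with anyio.create_task_group() as tg:
        tg.start_soon(worker)
        # The caller's `async with` body runs here while `worker` runs in the group;
        # leaving the task group waits for `worker` (or cancels it if an error propagates).
        yield done


async def main() -> None:
    async with with_background_task() as done:
        await done.wait()


anyio.run(main)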
pydantic_ai_slim/pydantic_ai/agent/__init__.py (10 changes: 7 additions & 3 deletions)
@@ -1,15 +1,16 @@
from __future__ import annotations as _annotations

import dataclasses
import functools
import inspect
import json
import warnings
from asyncio import Lock
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, ClassVar, cast, overload

import anyio
from opentelemetry.trace import NoOpTracer, use_span
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Self, TypeVar, deprecated
@@ -153,10 +154,14 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):

_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)

_enter_lock: Lock = dataclasses.field(repr=False)
_entered_count: int = dataclasses.field(repr=False)
_exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)

@functools.cached_property
def _enter_lock(self) -> anyio.Lock:
# We use a cached_property for this because it seems to work better with temporal...
return anyio.Lock()

@overload
def __init__(
self,
@@ -371,7 +376,6 @@ def __init__(
_utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
] = ContextVar('_override_tools', default=None)

self._enter_lock = Lock()
self._entered_count = 0
self._exit_stack = None

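The recurring pattern in this PR (here and again in mcp.py, google_vertex.py, and combined.py) replaces an eagerly-created asyncio Lock field with a lazily-created anyio.Lock behind functools.cached_property: the lock is only built on first access and is no longer a dataclass field that takes part in init or comparison. A minimal standalone sketch of the pattern; the class and method names are illustrative:

import functools

import anyio


class Guarded:
    @functools.cached_property
    def _enter_lock(self) -> anyio.Lock:
        # Created on first attribute access and cached on the instance afterwards.
        return anyio.Lock()

    async def enter(self) -> None:
        async with self._enter_lock:
            ...  # critical section


async def main() -> None:
    g = Guarded()
    await g.enter()
    await g.enter()  # reuses the same cached lock


anyio.run(main)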
pydantic_ai_slim/pydantic_ai/mcp.py (9 changes: 5 additions & 4 deletions)

@@ -4,10 +4,9 @@
import functools
import warnings
from abc import ABC, abstractmethod
from asyncio import Lock
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from dataclasses import field, replace
from dataclasses import replace
from datetime import timedelta
from pathlib import Path
from typing import Annotated, Any
@@ -105,14 +104,17 @@ class MCPServer(AbstractToolset[Any], ABC):

_id: str | None

_enter_lock: Lock = field(compare=False)
_running_count: int
_exit_stack: AsyncExitStack | None

_client: ClientSession
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
_write_stream: MemoryObjectSendStream[SessionMessage]

@functools.cached_property
def _enter_lock(self) -> anyio.Lock:
return anyio.Lock()
Collaborator comment on lines +115 to +116, with a suggested change adding the same explanatory comment used elsewhere:

Suggested change:
    def _enter_lock(self) -> anyio.Lock:
        # We use a cached_property for this because it seems to work better with temporal...
        return anyio.Lock()


def __init__(
self,
tool_prefix: str | None = None,
Expand Down Expand Up @@ -144,7 +146,6 @@ def __init__(
self.__post_init__()

def __post_init__(self):
self._enter_lock = Lock()
self._running_count = 0
self._exit_stack = None

pydantic_ai_slim/pydantic_ai/providers/google_vertex.py (9 changes: 6 additions & 3 deletions)
@@ -1,11 +1,11 @@
from __future__ import annotations as _annotations

import functools
from asyncio import Lock
from collections.abc import AsyncGenerator, Mapping
from pathlib import Path
from typing import Literal, overload

import anyio
import anyio.to_thread
import httpx
from typing_extensions import deprecated
@@ -119,10 +119,13 @@ def __init__(
class _VertexAIAuth(httpx.Auth):
"""Auth class for Vertex AI API."""

_refresh_lock: Lock = Lock()

credentials: BaseCredentials | ServiceAccountCredentials | None

@functools.cached_property
def _refresh_lock(self) -> anyio.Lock:
# We use a cached_property for this because it seems to work better with temporal...
return anyio.Lock()

def __init__(
self,
service_account_file: Path | str | None = None,
pydantic_ai_slim/pydantic_ai/toolsets/combined.py (13 changes: 9 additions & 4 deletions)
@@ -1,14 +1,15 @@
from __future__ import annotations

import asyncio
from asyncio import Lock
import functools
from collections.abc import Callable, Sequence
from contextlib import AsyncExitStack
from dataclasses import dataclass, field, replace
from typing import Any

import anyio
from typing_extensions import Self

from .. import _anyio
from .._run_context import AgentDepsT, RunContext
from ..exceptions import UserError
from .abstract import AbstractToolset, ToolsetTool
@@ -31,10 +32,14 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):

toolsets: Sequence[AbstractToolset[AgentDepsT]]

_enter_lock: Lock = field(compare=False, init=False, default_factory=Lock)
_entered_count: int = field(init=False, default=0)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)

@functools.cached_property
def _enter_lock(self) -> anyio.Lock:
# We use a cached_property for this because it seems to work better with temporal...
return anyio.Lock()

@property
def id(self) -> str | None:
return None # pragma: no cover
@@ -61,7 +66,7 @@ async def __aexit__(self, *args: Any) -> bool | None:
self._exit_stack = None

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
toolsets_tools = await _anyio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets))
all_tools: dict[str, ToolsetTool[AgentDepsT]] = {}

for toolset, tools in zip(self.toolsets, toolsets_tools):
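A quick usage sketch of the new helper as it is used in the hunk above: the awaitables all run concurrently inside one task group and the results come back in argument order. This assumes this branch of pydantic-ai is installed; slow_double and the timings are made up for illustration.

import anyio

from pydantic_ai import _anyio


async def slow_double(x: int) -> int:
    await anyio.sleep(0.1)
    return x * 2


async def main() -> None:
    results = await _anyio.gather(slow_double(1), slow_double(2), slow_double(3))
    print(results)  # [2, 4, 6], after roughly 0.1s total rather than 0.3s


anyio.run(main)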
pyproject.toml (2 changes: 2 additions & 0 deletions)

@@ -170,6 +170,8 @@ convention = "google"

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"typing.TypedDict".msg = "Use typing_extensions.TypedDict instead."
"asyncio.gather".msg = "Use pydantic_ai._anyio.run_all instead."
"asyncio.create_task".msg = "Use `anyio.create_task_group` instead."

[tool.ruff.format]
# don't format python in docstrings, pytest-examples takes care of it
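These banned-api entries make ruff flag direct asyncio.gather and asyncio.create_task calls under rule TID251; where such a call is still intentional (as in the _agent_graph.py hunk at the top of this diff), it is silenced inline. A small sketch of what that looks like in practice:

import asyncio


async def main() -> None:
    async def work() -> int:
        return 1

    # Normally flagged by the banned-api table above; suppressed where asyncio
    # is still deliberately used:
    task = asyncio.create_task(work())  # noqa: TID251
    print(await task)


asyncio.run(main())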