Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
from .models import Model
from .result import RunUsage

# TODO (v2): Change the default for all typevars like this from `None` to `object`
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
"""Type variable for agent dependencies."""

RunContextAgentDepsT = TypeVar('RunContextAgentDepsT', default=None, covariant=True)
"""Type variable for the agent dependencies in `RunContext`."""


@dataclasses.dataclass(repr=False, kw_only=True)
class RunContext(Generic[AgentDepsT]):
class RunContext(Generic[RunContextAgentDepsT]):
"""Information about the current call."""

deps: AgentDepsT
deps: RunContextAgentDepsT
"""Dependencies for the agent."""
model: Model
"""The model used in this run."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

from typing import Any

from typing_extensions import TypeVar

from pydantic_ai.exceptions import UserError
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.tools import RunContext

AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
"""Type variable for the agent dependencies in `RunContext`."""


class TemporalRunContext(RunContext[AgentDepsT]):
Expand Down Expand Up @@ -47,6 +52,6 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
}

@classmethod
def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
"""Deserialize the run context from a `dict[str, Any]`."""
return cls(**ctx, deps=deps)
16 changes: 10 additions & 6 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,20 @@ def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[st
return s


ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True)
"""Type variable for agent dependencies for a tool."""


@dataclass(init=False)
class Tool(Generic[AgentDepsT]):
class Tool(Generic[ToolAgentDepsT]):
"""A tool function for an agent."""

function: ToolFuncEither[AgentDepsT]
function: ToolFuncEither[ToolAgentDepsT]
takes_ctx: bool
max_retries: int | None
name: str
description: str | None
prepare: ToolPrepareFunc[AgentDepsT] | None
prepare: ToolPrepareFunc[ToolAgentDepsT] | None
docstring_format: DocstringFormat
require_parameter_descriptions: bool
strict: bool | None
Expand All @@ -265,13 +269,13 @@ class Tool(Generic[AgentDepsT]):

def __init__(
self,
function: ToolFuncEither[AgentDepsT],
function: ToolFuncEither[ToolAgentDepsT],
*,
takes_ctx: bool | None = None,
max_retries: int | None = None,
name: str | None = None,
description: str | None = None,
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
prepare: ToolPrepareFunc[ToolAgentDepsT] | None = None,
docstring_format: DocstringFormat = 'auto',
require_parameter_descriptions: bool = False,
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
Expand Down Expand Up @@ -413,7 +417,7 @@ def tool_def(self):
metadata=self.metadata,
)

async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
"""Get the tool definition.

By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
Expand Down
2 changes: 1 addition & 1 deletion tests/typed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def my_method(self) -> bool:
Tool(foobar_ctx, takes_ctx=True)
Tool(foobar_ctx)
Tool(foobar_plain, takes_ctx=False)
assert_type(Tool(foobar_plain), Tool[None])
assert_type(Tool(foobar_plain), Tool[object])
assert_type(Tool(foobar_plain), Tool)

# unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool
Expand Down
86 changes: 86 additions & 0 deletions tests/typed_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Any

from typing_extensions import assert_type

from pydantic_ai import Agent, RunContext, Tool, ToolDefinition


@dataclass
class DepsA:
a: int


@dataclass
class DepsB:
b: str


@dataclass
class AgentDeps(DepsA, DepsB):
pass


agent = Agent(
instructions='...',
model='...',
deps_type=AgentDeps,
)


@agent.tool
def tool_func_1(ctx: RunContext[DepsA]) -> int:
return ctx.deps.a


@agent.tool
def tool_func_2(ctx: RunContext[DepsB]) -> str:
return ctx.deps.b


# Ensure that you can use tools with deps that are supertypes of the agent's deps
agent.run_sync('...', deps=AgentDeps(a=0, b='test'))


def my_plain_tool() -> str:
return 'abc'


def my_context_tool(ctx: RunContext[int]) -> str:
return str(ctx.deps)


async def my_prepare_none(ctx: RunContext, tool_defn: ToolDefinition) -> None:
pass


async def my_prepare_object(ctx: RunContext[object], tool_defn: ToolDefinition) -> None:
pass


async def my_prepare_any(ctx: RunContext[Any], tool_defn: ToolDefinition) -> None:
pass


tool_1 = Tool(my_plain_tool)
assert_type(tool_1, Tool[object])

tool_2 = Tool(my_plain_tool, prepare=my_prepare_none)
assert_type(tool_2, Tool[None]) # due to default parameter of RunContext being None and inferring from prepare

tool_3 = Tool(my_plain_tool, prepare=my_prepare_object)
assert_type(tool_3, Tool[object])

tool_4 = Tool(my_plain_tool, prepare=my_prepare_any)
assert_type(tool_4, Tool[Any])

tool_5 = Tool(my_context_tool)
assert_type(tool_5, Tool[int])

tool_6 = Tool(my_context_tool, prepare=my_prepare_object)
assert_type(tool_6, Tool[int])

# Note: The following is not ideal behavior, but the workaround is to just not use Any as the argument to your prepare
# function, as shown in the example immediately above
tool_7 = Tool(my_context_tool, prepare=my_prepare_any)
assert_type(tool_7, Tool[Any])