From 5bc3cc26e9cdd86bac959d6c894685ad67d8028e Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 3 Nov 2025 11:04:11 -0700 Subject: [PATCH 1/2] Fix typevar variance for agent deps --- pydantic_ai_slim/pydantic_ai/_run_context.py | 7 +++- .../durable_exec/temporal/_run_context.py | 15 +++++-- tests/typed_deps.py | 39 +++++++++++++++++++ 3 files changed, 55 insertions(+), 6 deletions(-) create mode 100644 tests/typed_deps.py diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index df2a4c1b5a..2452253032 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -19,12 +19,15 @@ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) """Type variable for agent dependencies.""" +RunContextAgentDepsT = TypeVar('RunContextAgentDepsT', default=None, covariant=True) +"""Type variable for the agent dependencies in `RunContext`.""" + @dataclasses.dataclass(repr=False, kw_only=True) -class RunContext(Generic[AgentDepsT]): +class RunContext(Generic[RunContextAgentDepsT]): """Information about the current call.""" - deps: AgentDepsT + deps: RunContextAgentDepsT """Dependencies for the agent.""" model: Model """The model used in this run.""" diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py index 15c4e33de0..3d769ba4a2 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py @@ -2,8 +2,15 @@ from typing import Any +from typing_extensions import TypeVar + from pydantic_ai.exceptions import UserError -from pydantic_ai.tools import AgentDepsT, RunContext +from pydantic_ai.tools import RunContext + +AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True) +"""Type variable for the agent dependencies in `RunContext`.""" + +T = TypeVar('T', default=None) class TemporalRunContext(RunContext[AgentDepsT]): @@ -46,7 +53,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: 'run_step': ctx.run_step, } - @classmethod - def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]: + @staticmethod + def deserialize_run_context(ctx: dict[str, Any], deps: T) -> TemporalRunContext[T]: """Deserialize the run context from a `dict[str, Any]`.""" - return cls(**ctx, deps=deps) + return TemporalRunContext(**ctx, deps=deps) diff --git a/tests/typed_deps.py b/tests/typed_deps.py new file mode 100644 index 0000000000..64d0ec4199 --- /dev/null +++ b/tests/typed_deps.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + +from pydantic_ai import Agent, RunContext + + +@dataclass +class DepsA: + a: int + + +@dataclass +class DepsB: + b: str + + +@dataclass +class AgentDeps(DepsA, DepsB): + pass + + +agent = Agent( + instructions='...', + model='...', + deps_type=AgentDeps, +) + + +@agent.tool +def tool_1(ctx: RunContext[DepsA]) -> int: + return ctx.deps.a + + +@agent.tool +def tool_2(ctx: RunContext[DepsB]) -> str: + return ctx.deps.b + + +# Ensure that you can use tools with deps that are supertypes of the agent's deps +agent.run_sync('...', deps=AgentDeps(a=0, b='test')) From 7ac80a1ecd5352c57f9d6d837399eb9fbcb0b41c Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 3 Nov 2025 13:37:17 -0700 Subject: [PATCH 2/2] More typing improvements..? --- pydantic_ai_slim/pydantic_ai/_run_context.py | 1 + .../durable_exec/temporal/_run_context.py | 8 ++- pydantic_ai_slim/pydantic_ai/tools.py | 16 +++--- tests/typed_agent.py | 2 +- tests/typed_deps.py | 53 +++++++++++++++++-- 5 files changed, 65 insertions(+), 15 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 2452253032..e17afd78a8 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -16,6 +16,7 @@ from .models import Model from .result import RunUsage +# TODO (v2): Change the default for all typevars like this from `None` to `object` AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) """Type variable for agent dependencies.""" diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py index 3d769ba4a2..fa307dd68b 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py @@ -10,8 +10,6 @@ AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True) """Type variable for the agent dependencies in `RunContext`.""" -T = TypeVar('T', default=None) - class TemporalRunContext(RunContext[AgentDepsT]): """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity. @@ -53,7 +51,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: 'run_step': ctx.run_step, } - @staticmethod - def deserialize_run_context(ctx: dict[str, Any], deps: T) -> TemporalRunContext[T]: + @classmethod + def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]: """Deserialize the run context from a `dict[str, Any]`.""" - return TemporalRunContext(**ctx, deps=deps) + return cls(**ctx, deps=deps) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 844e99a25e..b3b1fc2324 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -240,16 +240,20 @@ def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[st return s +ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True) +"""Type variable for agent dependencies for a tool.""" + + @dataclass(init=False) -class Tool(Generic[AgentDepsT]): +class Tool(Generic[ToolAgentDepsT]): """A tool function for an agent.""" - function: ToolFuncEither[AgentDepsT] + function: ToolFuncEither[ToolAgentDepsT] takes_ctx: bool max_retries: int | None name: str description: str | None - prepare: ToolPrepareFunc[AgentDepsT] | None + prepare: ToolPrepareFunc[ToolAgentDepsT] | None docstring_format: DocstringFormat require_parameter_descriptions: bool strict: bool | None @@ -265,13 +269,13 @@ class Tool(Generic[AgentDepsT]): def __init__( self, - function: ToolFuncEither[AgentDepsT], + function: ToolFuncEither[ToolAgentDepsT], *, takes_ctx: bool | None = None, max_retries: int | None = None, name: str | None = None, description: str | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, + prepare: ToolPrepareFunc[ToolAgentDepsT] | None = None, docstring_format: DocstringFormat = 'auto', require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, @@ -413,7 +417,7 @@ def tool_def(self): metadata=self.metadata, ) - async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: + async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None: """Get the tool definition. By default, this method creates a tool definition, then either returns it, or calls `self.prepare` diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 1dbc2cb050..83c8f7bc3f 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -260,7 +260,7 @@ def my_method(self) -> bool: Tool(foobar_ctx, takes_ctx=True) Tool(foobar_ctx) Tool(foobar_plain, takes_ctx=False) -assert_type(Tool(foobar_plain), Tool[None]) +assert_type(Tool(foobar_plain), Tool[object]) assert_type(Tool(foobar_plain), Tool) # unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool diff --git a/tests/typed_deps.py b/tests/typed_deps.py index 64d0ec4199..20f3d47b9e 100644 --- a/tests/typed_deps.py +++ b/tests/typed_deps.py @@ -1,6 +1,9 @@ from dataclasses import dataclass +from typing import Any -from pydantic_ai import Agent, RunContext +from typing_extensions import assert_type + +from pydantic_ai import Agent, RunContext, Tool, ToolDefinition @dataclass @@ -26,14 +29,58 @@ class AgentDeps(DepsA, DepsB): @agent.tool -def tool_1(ctx: RunContext[DepsA]) -> int: +def tool_func_1(ctx: RunContext[DepsA]) -> int: return ctx.deps.a @agent.tool -def tool_2(ctx: RunContext[DepsB]) -> str: +def tool_func_2(ctx: RunContext[DepsB]) -> str: return ctx.deps.b # Ensure that you can use tools with deps that are supertypes of the agent's deps agent.run_sync('...', deps=AgentDeps(a=0, b='test')) + + +def my_plain_tool() -> str: + return 'abc' + + +def my_context_tool(ctx: RunContext[int]) -> str: + return str(ctx.deps) + + +async def my_prepare_none(ctx: RunContext, tool_defn: ToolDefinition) -> None: + pass + + +async def my_prepare_object(ctx: RunContext[object], tool_defn: ToolDefinition) -> None: + pass + + +async def my_prepare_any(ctx: RunContext[Any], tool_defn: ToolDefinition) -> None: + pass + + +tool_1 = Tool(my_plain_tool) +assert_type(tool_1, Tool[object]) + +tool_2 = Tool(my_plain_tool, prepare=my_prepare_none) +assert_type(tool_2, Tool[None]) # due to default parameter of RunContext being None and inferring from prepare + +tool_3 = Tool(my_plain_tool, prepare=my_prepare_object) +assert_type(tool_3, Tool[object]) + +tool_4 = Tool(my_plain_tool, prepare=my_prepare_any) +assert_type(tool_4, Tool[Any]) + +tool_5 = Tool(my_context_tool) +assert_type(tool_5, Tool[int]) + +tool_6 = Tool(my_context_tool, prepare=my_prepare_object) +assert_type(tool_6, Tool[int]) + +# Note: The following is not ideal behavior, but the workaround is to just not use Any as the argument to your prepare +# function, as shown in the example immediately above +tool_7 = Tool(my_context_tool, prepare=my_prepare_any) +assert_type(tool_7, Tool[Any])