Skip to content

Commit d16f634

Browse files
dmontaguDouweM
andauthored
Fix typevar variance for agent deps (#3319)
Co-authored-by: Douwe Maan <[email protected]>
1 parent d96bef8 commit d16f634

File tree

5 files changed

+110
-11
lines changed

5 files changed

+110
-11
lines changed

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
from .models import Model
1717
from .result import RunUsage
1818

19+
# TODO (v2): Change the default for all typevars like this from `None` to `object`
1920
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
2021
"""Type variable for agent dependencies."""
2122

23+
RunContextAgentDepsT = TypeVar('RunContextAgentDepsT', default=None, covariant=True)
24+
"""Type variable for the agent dependencies in `RunContext`."""
25+
2226

2327
@dataclasses.dataclass(repr=False, kw_only=True)
24-
class RunContext(Generic[AgentDepsT]):
28+
class RunContext(Generic[RunContextAgentDepsT]):
2529
"""Information about the current call."""
2630

27-
deps: AgentDepsT
31+
deps: RunContextAgentDepsT
2832
"""Dependencies for the agent."""
2933
model: Model
3034
"""The model used in this run."""

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22

33
from typing import Any
44

5+
from typing_extensions import TypeVar
6+
57
from pydantic_ai.exceptions import UserError
6-
from pydantic_ai.tools import AgentDepsT, RunContext
8+
from pydantic_ai.tools import RunContext
9+
10+
AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
11+
"""Type variable for the agent dependencies in `RunContext`."""
712

813

914
class TemporalRunContext(RunContext[AgentDepsT]):
@@ -47,6 +52,6 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
4752
}
4853

4954
@classmethod
50-
def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
55+
def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
5156
"""Deserialize the run context from a `dict[str, Any]`."""
5257
return cls(**ctx, deps=deps)

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,20 @@ def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[st
240240
return s
241241

242242

243+
ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True)
244+
"""Type variable for agent dependencies for a tool."""
245+
246+
243247
@dataclass(init=False)
244-
class Tool(Generic[AgentDepsT]):
248+
class Tool(Generic[ToolAgentDepsT]):
245249
"""A tool function for an agent."""
246250

247-
function: ToolFuncEither[AgentDepsT]
251+
function: ToolFuncEither[ToolAgentDepsT]
248252
takes_ctx: bool
249253
max_retries: int | None
250254
name: str
251255
description: str | None
252-
prepare: ToolPrepareFunc[AgentDepsT] | None
256+
prepare: ToolPrepareFunc[ToolAgentDepsT] | None
253257
docstring_format: DocstringFormat
254258
require_parameter_descriptions: bool
255259
strict: bool | None
@@ -265,13 +269,13 @@ class Tool(Generic[AgentDepsT]):
265269

266270
def __init__(
267271
self,
268-
function: ToolFuncEither[AgentDepsT],
272+
function: ToolFuncEither[ToolAgentDepsT],
269273
*,
270274
takes_ctx: bool | None = None,
271275
max_retries: int | None = None,
272276
name: str | None = None,
273277
description: str | None = None,
274-
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
278+
prepare: ToolPrepareFunc[ToolAgentDepsT] | None = None,
275279
docstring_format: DocstringFormat = 'auto',
276280
require_parameter_descriptions: bool = False,
277281
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
@@ -413,7 +417,7 @@ def tool_def(self):
413417
metadata=self.metadata,
414418
)
415419

416-
async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
420+
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
417421
"""Get the tool definition.
418422
419423
By default, this method creates a tool definition, then either returns it, or calls `self.prepare`

tests/typed_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def my_method(self) -> bool:
260260
Tool(foobar_ctx, takes_ctx=True)
261261
Tool(foobar_ctx)
262262
Tool(foobar_plain, takes_ctx=False)
263-
assert_type(Tool(foobar_plain), Tool[None])
263+
assert_type(Tool(foobar_plain), Tool[object])
264264
assert_type(Tool(foobar_plain), Tool)
265265

266266
# unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool

tests/typed_deps.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from dataclasses import dataclass
2+
from typing import Any
3+
4+
from typing_extensions import assert_type
5+
6+
from pydantic_ai import Agent, RunContext, Tool, ToolDefinition
7+
8+
9+
@dataclass
10+
class DepsA:
11+
a: int
12+
13+
14+
@dataclass
15+
class DepsB:
16+
b: str
17+
18+
19+
@dataclass
20+
class AgentDeps(DepsA, DepsB):
21+
pass
22+
23+
24+
agent = Agent(
25+
instructions='...',
26+
model='...',
27+
deps_type=AgentDeps,
28+
)
29+
30+
31+
@agent.tool
32+
def tool_func_1(ctx: RunContext[DepsA]) -> int:
33+
return ctx.deps.a
34+
35+
36+
@agent.tool
37+
def tool_func_2(ctx: RunContext[DepsB]) -> str:
38+
return ctx.deps.b
39+
40+
41+
# Ensure that you can use tools with deps that are supertypes of the agent's deps
42+
agent.run_sync('...', deps=AgentDeps(a=0, b='test'))
43+
44+
45+
def my_plain_tool() -> str:
46+
return 'abc'
47+
48+
49+
def my_context_tool(ctx: RunContext[int]) -> str:
50+
return str(ctx.deps)
51+
52+
53+
async def my_prepare_none(ctx: RunContext, tool_defn: ToolDefinition) -> None:
54+
pass
55+
56+
57+
async def my_prepare_object(ctx: RunContext[object], tool_defn: ToolDefinition) -> None:
58+
pass
59+
60+
61+
async def my_prepare_any(ctx: RunContext[Any], tool_defn: ToolDefinition) -> None:
62+
pass
63+
64+
65+
tool_1 = Tool(my_plain_tool)
66+
assert_type(tool_1, Tool[object])
67+
68+
tool_2 = Tool(my_plain_tool, prepare=my_prepare_none)
69+
assert_type(tool_2, Tool[None]) # due to default parameter of RunContext being None and inferring from prepare
70+
71+
tool_3 = Tool(my_plain_tool, prepare=my_prepare_object)
72+
assert_type(tool_3, Tool[object])
73+
74+
tool_4 = Tool(my_plain_tool, prepare=my_prepare_any)
75+
assert_type(tool_4, Tool[Any])
76+
77+
tool_5 = Tool(my_context_tool)
78+
assert_type(tool_5, Tool[int])
79+
80+
tool_6 = Tool(my_context_tool, prepare=my_prepare_object)
81+
assert_type(tool_6, Tool[int])
82+
83+
# Note: The following is not ideal behavior, but the workaround is to just not use Any as the argument to your prepare
84+
# function, as shown in the example immediately above
85+
tool_7 = Tool(my_context_tool, prepare=my_prepare_any)
86+
assert_type(tool_7, Tool[Any])

0 commit comments

Comments
 (0)