Skip to content

Commit 7ac80a1

Browse files
committed
More typing improvements..?
1 parent 5bc3cc2 commit 7ac80a1

File tree

5 files changed

+65
-15
lines changed

5 files changed

+65
-15
lines changed

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .models import Model
1717
from .result import RunUsage
1818

19+
# TODO (v2): Change the default for all typevars like this from `None` to `object`
1920
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
2021
"""Type variable for agent dependencies."""
2122

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
1111
"""Type variable for the agent dependencies in `RunContext`."""
1212

13-
T = TypeVar('T', default=None)
14-
1513

1614
class TemporalRunContext(RunContext[AgentDepsT]):
1715
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
@@ -53,7 +51,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
5351
'run_step': ctx.run_step,
5452
}
5553

56-
@staticmethod
57-
def deserialize_run_context(ctx: dict[str, Any], deps: T) -> TemporalRunContext[T]:
54+
@classmethod
55+
def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
5856
"""Deserialize the run context from a `dict[str, Any]`."""
59-
return TemporalRunContext(**ctx, deps=deps)
57+
return cls(**ctx, deps=deps)

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,20 @@ def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[st
240240
return s
241241

242242

243+
ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True)
244+
"""Type variable for agent dependencies for a tool."""
245+
246+
243247
@dataclass(init=False)
244-
class Tool(Generic[AgentDepsT]):
248+
class Tool(Generic[ToolAgentDepsT]):
245249
"""A tool function for an agent."""
246250

247-
function: ToolFuncEither[AgentDepsT]
251+
function: ToolFuncEither[ToolAgentDepsT]
248252
takes_ctx: bool
249253
max_retries: int | None
250254
name: str
251255
description: str | None
252-
prepare: ToolPrepareFunc[AgentDepsT] | None
256+
prepare: ToolPrepareFunc[ToolAgentDepsT] | None
253257
docstring_format: DocstringFormat
254258
require_parameter_descriptions: bool
255259
strict: bool | None
@@ -265,13 +269,13 @@ class Tool(Generic[AgentDepsT]):
265269

266270
def __init__(
267271
self,
268-
function: ToolFuncEither[AgentDepsT],
272+
function: ToolFuncEither[ToolAgentDepsT],
269273
*,
270274
takes_ctx: bool | None = None,
271275
max_retries: int | None = None,
272276
name: str | None = None,
273277
description: str | None = None,
274-
prepare: ToolPrepareFunc[AgentDepsT] | None = None,
278+
prepare: ToolPrepareFunc[ToolAgentDepsT] | None = None,
275279
docstring_format: DocstringFormat = 'auto',
276280
require_parameter_descriptions: bool = False,
277281
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
@@ -413,7 +417,7 @@ def tool_def(self):
413417
metadata=self.metadata,
414418
)
415419

416-
async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
420+
async def prepare_tool_def(self, ctx: RunContext[ToolAgentDepsT]) -> ToolDefinition | None:
417421
"""Get the tool definition.
418422
419423
By default, this method creates a tool definition, then either returns it, or calls `self.prepare`

tests/typed_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def my_method(self) -> bool:
260260
Tool(foobar_ctx, takes_ctx=True)
261261
Tool(foobar_ctx)
262262
Tool(foobar_plain, takes_ctx=False)
263-
assert_type(Tool(foobar_plain), Tool[None])
263+
assert_type(Tool(foobar_plain), Tool[object])
264264
assert_type(Tool(foobar_plain), Tool)
265265

266266
# unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool

tests/typed_deps.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from dataclasses import dataclass
2+
from typing import Any
23

3-
from pydantic_ai import Agent, RunContext
4+
from typing_extensions import assert_type
5+
6+
from pydantic_ai import Agent, RunContext, Tool, ToolDefinition
47

58

69
@dataclass
@@ -26,14 +29,58 @@ class AgentDeps(DepsA, DepsB):
2629

2730

2831
@agent.tool
29-
def tool_1(ctx: RunContext[DepsA]) -> int:
32+
def tool_func_1(ctx: RunContext[DepsA]) -> int:
3033
return ctx.deps.a
3134

3235

3336
@agent.tool
34-
def tool_2(ctx: RunContext[DepsB]) -> str:
37+
def tool_func_2(ctx: RunContext[DepsB]) -> str:
3538
return ctx.deps.b
3639

3740

3841
# Ensure that you can use tools with deps that are supertypes of the agent's deps
3942
agent.run_sync('...', deps=AgentDeps(a=0, b='test'))
43+
44+
45+
def my_plain_tool() -> str:
46+
return 'abc'
47+
48+
49+
def my_context_tool(ctx: RunContext[int]) -> str:
50+
return str(ctx.deps)
51+
52+
53+
async def my_prepare_none(ctx: RunContext, tool_defn: ToolDefinition) -> None:
54+
pass
55+
56+
57+
async def my_prepare_object(ctx: RunContext[object], tool_defn: ToolDefinition) -> None:
58+
pass
59+
60+
61+
async def my_prepare_any(ctx: RunContext[Any], tool_defn: ToolDefinition) -> None:
62+
pass
63+
64+
65+
tool_1 = Tool(my_plain_tool)
66+
assert_type(tool_1, Tool[object])
67+
68+
tool_2 = Tool(my_plain_tool, prepare=my_prepare_none)
69+
assert_type(tool_2, Tool[None]) # due to default parameter of RunContext being None and inferring from prepare
70+
71+
tool_3 = Tool(my_plain_tool, prepare=my_prepare_object)
72+
assert_type(tool_3, Tool[object])
73+
74+
tool_4 = Tool(my_plain_tool, prepare=my_prepare_any)
75+
assert_type(tool_4, Tool[Any])
76+
77+
tool_5 = Tool(my_context_tool)
78+
assert_type(tool_5, Tool[int])
79+
80+
tool_6 = Tool(my_context_tool, prepare=my_prepare_object)
81+
assert_type(tool_6, Tool[int])
82+
83+
# Note: The following is not ideal behavior, but the workaround is to just not use Any as the argument to your prepare
84+
# function, as shown in the example immediately above
85+
tool_7 = Tool(my_context_tool, prepare=my_prepare_any)
86+
assert_type(tool_7, Tool[Any])

0 commit comments

Comments
 (0)