Skip to content

Commit 8cd58ec

Browse files
AgentDeps default to None. (#592)
Co-authored-by: David Montague <[email protected]>
1 parent 1db2772 commit 8cd58ec

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
import inspect
55
from collections.abc import Awaitable
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
7+
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
88

99
from pydantic import ValidationError
1010
from pydantic_core import SchemaValidator
11-
from typing_extensions import Concatenate, ParamSpec, TypeAlias
11+
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
1212

1313
from . import _pydantic, _utils, messages as _messages, models
1414
from .exceptions import ModelRetry, UnexpectedModelBehavior
@@ -30,7 +30,7 @@
3030
'ToolDefinition',
3131
)
3232

33-
AgentDeps = TypeVar('AgentDeps')
33+
AgentDeps = TypeVar('AgentDeps', default=None)
3434
"""Type variable for agent dependencies."""
3535

3636

@@ -67,7 +67,7 @@ def replace_with(
6767
return dataclasses.replace(self, **kwargs)
6868

6969

70-
ToolParams = ParamSpec('ToolParams')
70+
ToolParams = ParamSpec('ToolParams', default=...)
7171
"""Retrieval function param spec."""
7272

7373
SystemPromptFunc = Union[
@@ -92,7 +92,7 @@ def replace_with(
9292
Usage `ToolPlainFunc[ToolParams]`.
9393
"""
9494
ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
95-
"""Either part_kind of tool function.
95+
"""Either kind of tool function.
9696
9797
This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
9898
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -134,7 +134,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
134134
class Tool(Generic[AgentDeps]):
135135
"""A tool function for an agent."""
136136

137-
function: ToolFuncEither[AgentDeps, ...]
137+
function: ToolFuncEither[AgentDeps]
138138
takes_ctx: bool
139139
max_retries: int | None
140140
name: str
@@ -150,7 +150,7 @@ class Tool(Generic[AgentDeps]):
150150

151151
def __init__(
152152
self,
153-
function: ToolFuncEither[AgentDeps, ...],
153+
function: ToolFuncEither[AgentDeps],
154154
*,
155155
takes_ctx: bool | None = None,
156156
max_retries: int | None = None,

tests/typed_agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def foobar_plain(x: str, y: int) -> str:
196196
Tool(foobar_ctx, takes_ctx=True)
197197
Tool(foobar_ctx)
198198
Tool(foobar_plain, takes_ctx=False)
199-
Tool(foobar_plain)
199+
assert_type(Tool(foobar_plain), Tool[None])
200+
assert_type(Tool(foobar_plain), Tool)
200201

201202
# unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool
202203
Tool(foobar_ctx, takes_ctx=False)
@@ -206,12 +207,15 @@ def foobar_plain(x: str, y: int) -> str:
206207
Agent('test', tools=[foobar_plain], deps_type=int)
207208
Agent('test', tools=[foobar_plain])
208209
Agent('test', tools=[Tool(foobar_ctx)], deps_type=int)
209-
Agent('test', tools=[Tool(foobar_plain)], deps_type=int)
210210
Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, foobar_plain], deps_type=int)
211+
Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, Tool(foobar_plain)], deps_type=int)
211212

212213
Agent('test', tools=[foobar_ctx], deps_type=str) # pyright: ignore[reportArgumentType]
214+
Agent('test', tools=[Tool(foobar_ctx), Tool(foobar_plain)], deps_type=str) # pyright: ignore[reportArgumentType]
213215
Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType]
214216
Agent('test', tools=[Tool(foobar_ctx)]) # pyright: ignore[reportArgumentType]
217+
# since deps are not set, they default to `None`, so can't be `int`
218+
Agent('test', tools=[Tool(foobar_plain)], deps_type=int) # pyright: ignore[reportArgumentType]
215219

216220
# prepare example from docs:
217221

@@ -238,6 +242,7 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD
238242
default_agent = Agent()
239243
assert_type(default_agent, Agent[None, str])
240244
assert_type(default_agent, Agent[None])
245+
assert_type(default_agent, Agent)
241246

242247
partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps)
243248
assert_type(partial_agent, Agent[MyDeps, str])

0 commit comments

Comments
 (0)