Skip to content

Commit 5bc3cc2

Browse files
committed
Fix typevar variance for agent deps
1 parent ab70709 commit 5bc3cc2

File tree

3 files changed

+55
-6
lines changed

3 files changed

+55
-6
lines changed

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
2020
"""Type variable for agent dependencies."""
2121

22+
RunContextAgentDepsT = TypeVar('RunContextAgentDepsT', default=None, covariant=True)
23+
"""Type variable for the agent dependencies in `RunContext`."""
24+
2225

2326
@dataclasses.dataclass(repr=False, kw_only=True)
24-
class RunContext(Generic[AgentDepsT]):
27+
class RunContext(Generic[RunContextAgentDepsT]):
2528
"""Information about the current call."""
2629

27-
deps: AgentDepsT
30+
deps: RunContextAgentDepsT
2831
"""Dependencies for the agent."""
2932
model: Model
3033
"""The model used in this run."""

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@
22

33
from typing import Any
44

5+
from typing_extensions import TypeVar
6+
57
from pydantic_ai.exceptions import UserError
6-
from pydantic_ai.tools import AgentDepsT, RunContext
8+
from pydantic_ai.tools import RunContext
9+
10+
AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
11+
"""Type variable for the agent dependencies in `RunContext`."""
12+
13+
T = TypeVar('T', default=None)
714

815

916
class TemporalRunContext(RunContext[AgentDepsT]):
@@ -46,7 +53,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
4653
'run_step': ctx.run_step,
4754
}
4855

49-
@classmethod
50-
def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
56+
@staticmethod
57+
def deserialize_run_context(ctx: dict[str, Any], deps: T) -> TemporalRunContext[T]:
5158
"""Deserialize the run context from a `dict[str, Any]`."""
52-
return cls(**ctx, deps=deps)
59+
return TemporalRunContext(**ctx, deps=deps)

tests/typed_deps.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from dataclasses import dataclass
2+
3+
from pydantic_ai import Agent, RunContext
4+
5+
6+
@dataclass
7+
class DepsA:
8+
a: int
9+
10+
11+
@dataclass
12+
class DepsB:
13+
b: str
14+
15+
16+
@dataclass
17+
class AgentDeps(DepsA, DepsB):
18+
pass
19+
20+
21+
agent = Agent(
22+
instructions='...',
23+
model='...',
24+
deps_type=AgentDeps,
25+
)
26+
27+
28+
@agent.tool
29+
def tool_1(ctx: RunContext[DepsA]) -> int:
30+
return ctx.deps.a
31+
32+
33+
@agent.tool
34+
def tool_2(ctx: RunContext[DepsB]) -> str:
35+
return ctx.deps.b
36+
37+
38+
# Ensure that you can use tools with deps that are supertypes of the agent's deps
39+
agent.run_sync('...', deps=AgentDeps(a=0, b='test'))

0 commit comments

Comments
 (0)