Skip to content

Commit 8ede1b5

Browse files
authored
Improve variance of classes (#726)
1 parent 073983c commit 8ede1b5

File tree

7 files changed

+210
-170
lines changed

7 files changed

+210
-170
lines changed

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,20 @@
88
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
99

1010
from pydantic import TypeAdapter, ValidationError
11-
from typing_extensions import Self, TypeAliasType, TypedDict
11+
from typing_extensions import TypeAliasType, TypedDict, TypeVar
1212

1313
from . import _utils, messages as _messages
1414
from .exceptions import ModelRetry
15-
from .result import ResultData, ResultValidatorFunc
16-
from .tools import AgentDeps, RunContext, ToolDefinition
15+
from .result import ResultDataT, ResultDataT_inv, ResultValidatorFunc
16+
from .tools import AgentDepsT, RunContext, ToolDefinition
17+
18+
T = TypeVar('T')
19+
"""An invariant TypeVar."""
1720

1821

1922
@dataclass
20-
class ResultValidator(Generic[AgentDeps, ResultData]):
21-
function: ResultValidatorFunc[AgentDeps, ResultData]
23+
class ResultValidator(Generic[AgentDepsT, ResultDataT_inv]):
24+
function: ResultValidatorFunc[AgentDepsT, ResultDataT_inv]
2225
_takes_ctx: bool = field(init=False)
2326
_is_async: bool = field(init=False)
2427

@@ -28,10 +31,10 @@ def __post_init__(self):
2831

2932
async def validate(
3033
self,
31-
result: ResultData,
34+
result: T,
3235
tool_call: _messages.ToolCallPart | None,
33-
run_context: RunContext[AgentDeps],
34-
) -> ResultData:
36+
run_context: RunContext[AgentDepsT],
37+
) -> T:
3538
"""Validate a result but calling the function.
3639
3740
Args:
@@ -50,10 +53,10 @@ async def validate(
5053

5154
try:
5255
if self._is_async:
53-
function = cast(Callable[[Any], Awaitable[ResultData]], self.function)
56+
function = cast(Callable[[Any], Awaitable[T]], self.function)
5457
result_data = await function(*args)
5558
else:
56-
function = cast(Callable[[Any], ResultData], self.function)
59+
function = cast(Callable[[Any], T], self.function)
5760
result_data = await _utils.run_in_executor(function, *args)
5861
except ModelRetry as r:
5962
m = _messages.RetryPromptPart(content=r.message)
@@ -74,17 +77,19 @@ def __init__(self, tool_retry: _messages.RetryPromptPart):
7477

7578

7679
@dataclass
77-
class ResultSchema(Generic[ResultData]):
80+
class ResultSchema(Generic[ResultDataT]):
7881
"""Model the final response from an agent run.
7982
8083
Similar to `Tool` but for the final result of running an agent.
8184
"""
8285

83-
tools: dict[str, ResultTool[ResultData]]
86+
tools: dict[str, ResultTool[ResultDataT]]
8487
allow_text_result: bool
8588

8689
@classmethod
87-
def build(cls, response_type: type[ResultData], name: str, description: str | None) -> Self | None:
90+
def build(
91+
cls: type[ResultSchema[T]], response_type: type[T], name: str, description: str | None
92+
) -> ResultSchema[T] | None:
8893
"""Build a ResultSchema dataclass from a response type."""
8994
if response_type is str:
9095
return None
@@ -95,10 +100,10 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
95100
else:
96101
allow_text_result = False
97102

98-
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
99-
return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
103+
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[T]:
104+
return cast(ResultTool[T], ResultTool(a, tool_name_, description, multiple))
100105

101-
tools: dict[str, ResultTool[ResultData]] = {}
106+
tools: dict[str, ResultTool[T]] = {}
102107
if args := get_union_args(response_type):
103108
for i, arg in enumerate(args, start=1):
104109
tool_name = union_tool_name(name, arg)
@@ -112,7 +117,7 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
112117

113118
def find_named_tool(
114119
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
115-
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
120+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
116121
"""Find a tool that matches one of the calls, with a specific name."""
117122
for part in parts:
118123
if isinstance(part, _messages.ToolCallPart):
@@ -122,7 +127,7 @@ def find_named_tool(
122127
def find_tool(
123128
self,
124129
parts: Iterable[_messages.ModelResponsePart],
125-
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
130+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
126131
"""Find a tool that matches one of the calls."""
127132
for part in parts:
128133
if isinstance(part, _messages.ToolCallPart):
@@ -142,11 +147,11 @@ def tool_defs(self) -> list[ToolDefinition]:
142147

143148

144149
@dataclass(init=False)
145-
class ResultTool(Generic[ResultData]):
150+
class ResultTool(Generic[ResultDataT]):
146151
tool_def: ToolDefinition
147152
type_adapter: TypeAdapter[Any]
148153

149-
def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
154+
def __init__(self, response_type: type[ResultDataT], name: str, description: str | None, multiple: bool):
150155
"""Build a ResultTool dataclass from a response type."""
151156
assert response_type is not str, 'ResultTool does not support str as a response type'
152157

@@ -183,7 +188,7 @@ def __init__(self, response_type: type[ResultData], name: str, description: str
183188

184189
def validate(
185190
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
186-
) -> ResultData:
191+
) -> ResultDataT:
187192
"""Validate a result message.
188193
189194
Args:

pydantic_ai_slim/pydantic_ai/_system_prompt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from typing import Any, Callable, Generic, cast
77

88
from . import _utils
9-
from .tools import AgentDeps, RunContext, SystemPromptFunc
9+
from .tools import AgentDepsT, RunContext, SystemPromptFunc
1010

1111

1212
@dataclass
13-
class SystemPromptRunner(Generic[AgentDeps]):
14-
function: SystemPromptFunc[AgentDeps]
13+
class SystemPromptRunner(Generic[AgentDepsT]):
14+
function: SystemPromptFunc[AgentDepsT]
1515
dynamic: bool = False
1616
_takes_ctx: bool = field(init=False)
1717
_is_async: bool = field(init=False)
@@ -20,7 +20,7 @@ def __post_init__(self):
2020
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
2121
self._is_async = inspect.iscoroutinefunction(self.function)
2222

23-
async def run(self, run_context: RunContext[AgentDeps]) -> str:
23+
async def run(self, run_context: RunContext[AgentDepsT]) -> str:
2424
if self._takes_ctx:
2525
args = (run_context,)
2626
else:

0 commit comments

Comments
 (0)