8
8
from typing import Any , Callable , Generic , Literal , Union , cast , get_args , get_origin
9
9
10
10
from pydantic import TypeAdapter , ValidationError
11
- from typing_extensions import Self , TypeAliasType , TypedDict
11
+ from typing_extensions import TypeAliasType , TypedDict , TypeVar
12
12
13
13
from . import _utils , messages as _messages
14
14
from .exceptions import ModelRetry
15
- from .result import ResultData , ResultValidatorFunc
16
- from .tools import AgentDeps , RunContext , ToolDefinition
15
+ from .result import ResultDataT , ResultDataT_inv , ResultValidatorFunc
16
+ from .tools import AgentDepsT , RunContext , ToolDefinition
17
+
18
+ T = TypeVar ('T' )
19
+ """An invariant TypeVar."""
17
20
18
21
19
22
@dataclass
20
- class ResultValidator (Generic [AgentDeps , ResultData ]):
21
- function : ResultValidatorFunc [AgentDeps , ResultData ]
23
+ class ResultValidator (Generic [AgentDepsT , ResultDataT_inv ]):
24
+ function : ResultValidatorFunc [AgentDepsT , ResultDataT_inv ]
22
25
_takes_ctx : bool = field (init = False )
23
26
_is_async : bool = field (init = False )
24
27
@@ -28,10 +31,10 @@ def __post_init__(self):
28
31
29
32
async def validate (
30
33
self ,
31
- result : ResultData ,
34
+ result : T ,
32
35
tool_call : _messages .ToolCallPart | None ,
33
- run_context : RunContext [AgentDeps ],
34
- ) -> ResultData :
36
+ run_context : RunContext [AgentDepsT ],
37
+ ) -> T :
35
38
"""Validate a result but calling the function.
36
39
37
40
Args:
@@ -50,10 +53,10 @@ async def validate(
50
53
51
54
try :
52
55
if self ._is_async :
53
- function = cast (Callable [[Any ], Awaitable [ResultData ]], self .function )
56
+ function = cast (Callable [[Any ], Awaitable [T ]], self .function )
54
57
result_data = await function (* args )
55
58
else :
56
- function = cast (Callable [[Any ], ResultData ], self .function )
59
+ function = cast (Callable [[Any ], T ], self .function )
57
60
result_data = await _utils .run_in_executor (function , * args )
58
61
except ModelRetry as r :
59
62
m = _messages .RetryPromptPart (content = r .message )
@@ -74,17 +77,19 @@ def __init__(self, tool_retry: _messages.RetryPromptPart):
74
77
75
78
76
79
@dataclass
77
- class ResultSchema (Generic [ResultData ]):
80
+ class ResultSchema (Generic [ResultDataT ]):
78
81
"""Model the final response from an agent run.
79
82
80
83
Similar to `Tool` but for the final result of running an agent.
81
84
"""
82
85
83
- tools : dict [str , ResultTool [ResultData ]]
86
+ tools : dict [str , ResultTool [ResultDataT ]]
84
87
allow_text_result : bool
85
88
86
89
@classmethod
87
- def build (cls , response_type : type [ResultData ], name : str , description : str | None ) -> Self | None :
90
+ def build (
91
+ cls : type [ResultSchema [T ]], response_type : type [T ], name : str , description : str | None
92
+ ) -> ResultSchema [T ] | None :
88
93
"""Build a ResultSchema dataclass from a response type."""
89
94
if response_type is str :
90
95
return None
@@ -95,10 +100,10 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
95
100
else :
96
101
allow_text_result = False
97
102
98
- def _build_tool (a : Any , tool_name_ : str , multiple : bool ) -> ResultTool [ResultData ]:
99
- return cast (ResultTool [ResultData ], ResultTool (a , tool_name_ , description , multiple ))
103
+ def _build_tool (a : Any , tool_name_ : str , multiple : bool ) -> ResultTool [T ]:
104
+ return cast (ResultTool [T ], ResultTool (a , tool_name_ , description , multiple ))
100
105
101
- tools : dict [str , ResultTool [ResultData ]] = {}
106
+ tools : dict [str , ResultTool [T ]] = {}
102
107
if args := get_union_args (response_type ):
103
108
for i , arg in enumerate (args , start = 1 ):
104
109
tool_name = union_tool_name (name , arg )
@@ -112,7 +117,7 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
112
117
113
118
def find_named_tool (
114
119
self , parts : Iterable [_messages .ModelResponsePart ], tool_name : str
115
- ) -> tuple [_messages .ToolCallPart , ResultTool [ResultData ]] | None :
120
+ ) -> tuple [_messages .ToolCallPart , ResultTool [ResultDataT ]] | None :
116
121
"""Find a tool that matches one of the calls, with a specific name."""
117
122
for part in parts :
118
123
if isinstance (part , _messages .ToolCallPart ):
@@ -122,7 +127,7 @@ def find_named_tool(
122
127
def find_tool (
123
128
self ,
124
129
parts : Iterable [_messages .ModelResponsePart ],
125
- ) -> tuple [_messages .ToolCallPart , ResultTool [ResultData ]] | None :
130
+ ) -> tuple [_messages .ToolCallPart , ResultTool [ResultDataT ]] | None :
126
131
"""Find a tool that matches one of the calls."""
127
132
for part in parts :
128
133
if isinstance (part , _messages .ToolCallPart ):
@@ -142,11 +147,11 @@ def tool_defs(self) -> list[ToolDefinition]:
142
147
143
148
144
149
@dataclass (init = False )
145
- class ResultTool (Generic [ResultData ]):
150
+ class ResultTool (Generic [ResultDataT ]):
146
151
tool_def : ToolDefinition
147
152
type_adapter : TypeAdapter [Any ]
148
153
149
- def __init__ (self , response_type : type [ResultData ], name : str , description : str | None , multiple : bool ):
154
+ def __init__ (self , response_type : type [ResultDataT ], name : str , description : str | None , multiple : bool ):
150
155
"""Build a ResultTool dataclass from a response type."""
151
156
assert response_type is not str , 'ResultTool does not support str as a response type'
152
157
@@ -183,7 +188,7 @@ def __init__(self, response_type: type[ResultData], name: str, description: str
183
188
184
189
def validate (
185
190
self , tool_call : _messages .ToolCallPart , allow_partial : bool = False , wrap_validation_errors : bool = True
186
- ) -> ResultData :
191
+ ) -> ResultDataT :
187
192
"""Validate a result message.
188
193
189
194
Args:
0 commit comments