34
34
from .function import _estimate_string_tokens , _estimate_usage # pyright: ignore[reportPrivateUsage]
35
35
36
36
37
+ @dataclass
38
+ class _TextResult :
39
+ """A private wrapper class to tag a result that came from the custom_result_text field."""
40
+
41
+ value : str | None
42
+
43
+
44
+ @dataclass
45
+ class _FunctionToolResult :
46
+ """A wrapper class to tag a result that came from the custom_result_args field."""
47
+
48
+ value : Any | None
49
+
50
+
37
51
@dataclass
38
52
class TestModel (Model ):
39
53
"""A model specifically for testing purposes.
@@ -53,7 +67,7 @@ class TestModel(Model):
53
67
call_tools : list [str ] | Literal ['all' ] = 'all'
54
68
"""List of tools to call. If `'all'`, all tools will be called."""
55
69
custom_result_text : str | None = None
56
- """If set, this text is return as the final result."""
70
+ """If set, this text is returned as the final result."""
57
71
custom_result_args : Any | None = None
58
72
"""If set, these args will be passed to the result tool."""
59
73
seed : int = 0
@@ -95,21 +109,21 @@ async def agent_model(
95
109
if self .custom_result_text is not None :
96
110
assert allow_text_result , 'Plain response not allowed, but `custom_result_text` is set.'
97
111
assert self .custom_result_args is None , 'Cannot set both `custom_result_text` and `custom_result_args`.'
98
- result : _utils . Either [ str | None , Any | None ] = _utils . Either ( left = self .custom_result_text )
112
+ result : _TextResult | _FunctionToolResult = _TextResult ( self .custom_result_text )
99
113
elif self .custom_result_args is not None :
100
114
assert result_tools is not None , 'No result tools provided, but `custom_result_args` is set.'
101
115
result_tool = result_tools [0 ]
102
116
103
117
if k := result_tool .outer_typed_dict_key :
104
- result = _utils . Either ( right = {k : self .custom_result_args })
118
+ result = _FunctionToolResult ( {k : self .custom_result_args })
105
119
else :
106
- result = _utils . Either ( right = self .custom_result_args )
120
+ result = _FunctionToolResult ( self .custom_result_args )
107
121
elif allow_text_result :
108
- result = _utils . Either ( left = None )
122
+ result = _TextResult ( None )
109
123
elif result_tools :
110
- result = _utils . Either ( right = None )
124
+ result = _FunctionToolResult ( None )
111
125
else :
112
- result = _utils . Either ( left = None )
126
+ result = _TextResult ( None )
113
127
114
128
return TestAgentModel (tool_calls , result , result_tools , self .seed )
115
129
@@ -126,7 +140,7 @@ class TestAgentModel(AgentModel):
126
140
127
141
tool_calls : list [tuple [str , ToolDefinition ]]
128
142
# left means the text is plain text; right means it's a function call
129
- result : _utils . Either [ str | None , Any | None ]
143
+ result : _TextResult | _FunctionToolResult
130
144
result_tools : list [ToolDefinition ]
131
145
seed : int
132
146
model_name : str = 'test'
@@ -176,16 +190,18 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
176
190
[
177
191
ToolCallPart .from_raw_args (
178
192
tool .name ,
179
- self .result .right if self .result .right is not None else self .gen_tool_args (tool ),
193
+ self .result .value
194
+ if isinstance (self .result , _FunctionToolResult ) and self .result .value is not None
195
+ else self .gen_tool_args (tool ),
180
196
)
181
197
for tool in self .result_tools
182
198
if tool .name in new_retry_names
183
199
]
184
200
)
185
201
return ModelResponse (parts = retry_parts , model_name = self .model_name )
186
202
187
- if response_text := self .result . left :
188
- if response_text . value is None :
203
+ if isinstance ( self .result , _TextResult ) :
204
+ if ( response_text := self . result . value ) is None :
189
205
# build up details of tool responses
190
206
output : dict [str , Any ] = {}
191
207
for message in messages :
@@ -200,10 +216,10 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
200
216
else :
201
217
return ModelResponse (parts = [TextPart ('success (no tool calls)' )], model_name = self .model_name )
202
218
else :
203
- return ModelResponse (parts = [TextPart (response_text . value )], model_name = self .model_name )
219
+ return ModelResponse (parts = [TextPart (response_text )], model_name = self .model_name )
204
220
else :
205
221
assert self .result_tools , 'No result tools provided'
206
- custom_result_args = self .result .right
222
+ custom_result_args = self .result .value
207
223
result_tool = self .result_tools [self .seed % len (self .result_tools )]
208
224
if custom_result_args is not None :
209
225
return ModelResponse (
0 commit comments