28
28
from ..settings import ModelSettings
29
29
from ..tools import ToolDefinition
30
30
from . import (
31
- AgentModel ,
32
31
Model ,
32
+ ModelRequestParameters ,
33
33
StreamedResponse ,
34
34
cached_async_http_client ,
35
35
check_allow_model_requests ,
@@ -134,81 +134,70 @@ def __init__(
134
134
else :
135
135
self .client = AsyncAnthropic (api_key = api_key , http_client = cached_async_http_client ())
136
136
137
- async def agent_model (
138
- self ,
139
- * ,
140
- function_tools : list [ToolDefinition ],
141
- allow_text_result : bool ,
142
- result_tools : list [ToolDefinition ],
143
- ) -> AgentModel :
144
- check_allow_model_requests ()
145
- tools = [self ._map_tool_definition (r ) for r in function_tools ]
146
- if result_tools :
147
- tools += [self ._map_tool_definition (r ) for r in result_tools ]
148
- return AnthropicAgentModel (
149
- self .client ,
150
- self .model_name ,
151
- allow_text_result ,
152
- tools ,
153
- )
154
-
155
137
def name (self ) -> str :
156
138
return f'anthropic:{ self .model_name } '
157
139
158
- @staticmethod
159
- def _map_tool_definition (f : ToolDefinition ) -> ToolParam :
160
- return {
161
- 'name' : f .name ,
162
- 'description' : f .description ,
163
- 'input_schema' : f .parameters_json_schema ,
164
- }
165
-
166
-
167
- @dataclass
168
- class AnthropicAgentModel (AgentModel ):
169
- """Implementation of `AgentModel` for Anthropic models."""
170
-
171
- client : AsyncAnthropic
172
- model_name : AnthropicModelName
173
- allow_text_result : bool
174
- tools : list [ToolParam ]
175
-
176
140
async def request (
177
- self , messages : list [ModelMessage ], model_settings : ModelSettings | None
141
+ self ,
142
+ messages : list [ModelMessage ],
143
+ model_settings : ModelSettings | None ,
144
+ model_request_parameters : ModelRequestParameters ,
178
145
) -> tuple [ModelResponse , usage .Usage ]:
179
- response = await self ._messages_create (messages , False , cast (AnthropicModelSettings , model_settings or {}))
146
+ check_allow_model_requests ()
147
+ response = await self ._messages_create (
148
+ messages , False , cast (AnthropicModelSettings , model_settings or {}), model_request_parameters
149
+ )
180
150
return self ._process_response (response ), _map_usage (response )
181
151
182
152
@asynccontextmanager
183
153
async def request_stream (
184
- self , messages : list [ModelMessage ], model_settings : ModelSettings | None
154
+ self ,
155
+ messages : list [ModelMessage ],
156
+ model_settings : ModelSettings | None ,
157
+ model_request_parameters : ModelRequestParameters ,
185
158
) -> AsyncIterator [StreamedResponse ]:
186
- response = await self ._messages_create (messages , True , cast (AnthropicModelSettings , model_settings or {}))
159
+ check_allow_model_requests ()
160
+ response = await self ._messages_create (
161
+ messages , True , cast (AnthropicModelSettings , model_settings or {}), model_request_parameters
162
+ )
187
163
async with response :
188
164
yield await self ._process_streamed_response (response )
189
165
190
166
@overload
191
167
async def _messages_create (
192
- self , messages : list [ModelMessage ], stream : Literal [True ], model_settings : AnthropicModelSettings
168
+ self ,
169
+ messages : list [ModelMessage ],
170
+ stream : Literal [True ],
171
+ model_settings : AnthropicModelSettings ,
172
+ model_request_parameters : ModelRequestParameters ,
193
173
) -> AsyncStream [RawMessageStreamEvent ]:
194
174
pass
195
175
196
176
@overload
197
177
async def _messages_create (
198
- self , messages : list [ModelMessage ], stream : Literal [False ], model_settings : AnthropicModelSettings
178
+ self ,
179
+ messages : list [ModelMessage ],
180
+ stream : Literal [False ],
181
+ model_settings : AnthropicModelSettings ,
182
+ model_request_parameters : ModelRequestParameters ,
199
183
) -> AnthropicMessage :
200
184
pass
201
185
202
186
async def _messages_create (
203
- self , messages : list [ModelMessage ], stream : bool , model_settings : AnthropicModelSettings
187
+ self ,
188
+ messages : list [ModelMessage ],
189
+ stream : bool ,
190
+ model_settings : AnthropicModelSettings ,
191
+ model_request_parameters : ModelRequestParameters ,
204
192
) -> AnthropicMessage | AsyncStream [RawMessageStreamEvent ]:
205
193
# standalone function to make it easier to override
194
+ tools = self ._get_tools (model_request_parameters )
206
195
tool_choice : ToolChoiceParam | None
207
196
208
- if not self . tools :
197
+ if not tools :
209
198
tool_choice = None
210
199
else :
211
- if not self .allow_text_result :
200
+ if not model_request_parameters .allow_text_result :
212
201
tool_choice = {'type' : 'any' }
213
202
else :
214
203
tool_choice = {'type' : 'auto' }
@@ -223,7 +212,7 @@ async def _messages_create(
223
212
system = system_prompt or NOT_GIVEN ,
224
213
messages = anthropic_messages ,
225
214
model = self .model_name ,
226
- tools = self . tools or NOT_GIVEN ,
215
+ tools = tools or NOT_GIVEN ,
227
216
tool_choice = tool_choice or NOT_GIVEN ,
228
217
stream = stream ,
229
218
temperature = model_settings .get ('temperature' , NOT_GIVEN ),
@@ -260,8 +249,13 @@ async def _process_streamed_response(self, response: AsyncStream[RawMessageStrea
260
249
timestamp = datetime .now (tz = timezone .utc )
261
250
return AnthropicStreamedResponse (_model_name = self .model_name , _response = peekable_response , _timestamp = timestamp )
262
251
263
- @staticmethod
264
- def _map_message (messages : list [ModelMessage ]) -> tuple [str , list [MessageParam ]]:
252
+ def _get_tools (self , model_request_parameters : ModelRequestParameters ) -> list [ToolParam ]:
253
+ tools = [self ._map_tool_definition (r ) for r in model_request_parameters .function_tools ]
254
+ if model_request_parameters .result_tools :
255
+ tools += [self ._map_tool_definition (r ) for r in model_request_parameters .result_tools ]
256
+ return tools
257
+
258
+ def _map_message (self , messages : list [ModelMessage ]) -> tuple [str , list [MessageParam ]]:
265
259
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
266
260
system_prompt : str = ''
267
261
anthropic_messages : list [MessageParam ] = []
@@ -310,20 +304,28 @@ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]
310
304
content .append (TextBlockParam (text = item .content , type = 'text' ))
311
305
else :
312
306
assert isinstance (item , ToolCallPart )
313
- content .append (_map_tool_call (item ))
307
+ content .append (self . _map_tool_call (item ))
314
308
anthropic_messages .append (MessageParam (role = 'assistant' , content = content ))
315
309
else :
316
310
assert_never (m )
317
311
return system_prompt , anthropic_messages
318
312
313
+ @staticmethod
314
+ def _map_tool_call (t : ToolCallPart ) -> ToolUseBlockParam :
315
+ return ToolUseBlockParam (
316
+ id = _guard_tool_call_id (t = t , model_source = 'Anthropic' ),
317
+ type = 'tool_use' ,
318
+ name = t .tool_name ,
319
+ input = t .args_as_dict (),
320
+ )
319
321
320
- def _map_tool_call ( t : ToolCallPart ) -> ToolUseBlockParam :
321
- return ToolUseBlockParam (
322
- id = _guard_tool_call_id ( t = t , model_source = 'Anthropic' ),
323
- type = 'tool_use' ,
324
- name = t . tool_name ,
325
- input = t . args_as_dict () ,
326
- )
322
+ @ staticmethod
323
+ def _map_tool_definition ( f : ToolDefinition ) -> ToolParam :
324
+ return {
325
+ 'name' : f . name ,
326
+ 'description' : f . description ,
327
+ 'input_schema' : f . parameters_json_schema ,
328
+ }
327
329
328
330
329
331
def _map_usage (message : AnthropicMessage | RawMessageStreamEvent ) -> usage .Usage :
0 commit comments