|
10 | 10 |
|
11 | 11 | import anyio
|
12 | 12 | import anyio.to_thread
|
| 13 | +from mypy_boto3_bedrock_runtime.type_defs import ConverseRequestTypeDef, SystemContentBlockTypeDef |
13 | 14 | from typing_extensions import ParamSpec, assert_never
|
14 | 15 |
|
15 | 16 | from pydantic_ai import _utils, result
|
@@ -258,20 +259,19 @@ async def _messages_create(
|
258 | 259 | else:
|
259 | 260 | tool_choice = {'auto': {}}
|
260 | 261 |
|
261 |
| - system_prompt, bedrock_messages = await self._map_message(messages) |
| 262 | + system_prompt, bedrock_messages = await self._map_messages(messages) |
262 | 263 | inference_config = self._map_inference_config(model_settings)
|
263 | 264 |
|
264 |
| - params = { |
| 265 | + params: ConverseRequestTypeDef = { |
265 | 266 | 'modelId': self.model_name,
|
266 | 267 | 'messages': bedrock_messages,
|
267 |
| - 'system': [{'text': system_prompt}], |
| 268 | + 'system': system_prompt, |
268 | 269 | 'inferenceConfig': inference_config,
|
269 |
| - **( |
270 |
| - {'toolConfig': {'tools': tools, **({'toolChoice': tool_choice} if tool_choice else {})}} |
271 |
| - if tools |
272 |
| - else {} |
273 |
| - ), |
274 | 270 | }
|
| 271 | + if tools: |
| 272 | + params['toolConfig'] = {'tools': tools} |
| 273 | + if tool_choice: |
| 274 | + params['toolConfig']['toolChoice'] = tool_choice |
275 | 275 |
|
276 | 276 | if stream:
|
277 | 277 | model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
|
@@ -299,15 +299,17 @@ def _map_inference_config(
|
299 | 299 |
|
300 | 300 | return inference_config
|
301 | 301 |
|
302 |
| - async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageUnionTypeDef]]: |
| 302 | + async def _map_messages( |
| 303 | + self, messages: list[ModelMessage] |
| 304 | + ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]: |
303 | 305 | """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
|
304 |
| - system_prompt: str = '' |
| 306 | + system_prompt: list[SystemContentBlockTypeDef] = [] |
305 | 307 | bedrock_messages: list[MessageUnionTypeDef] = []
|
306 | 308 | for m in messages:
|
307 | 309 | if isinstance(m, ModelRequest):
|
308 | 310 | for part in m.parts:
|
309 | 311 | if isinstance(part, SystemPromptPart):
|
310 |
| - system_prompt += part.content |
| 312 | + system_prompt.append({'text': part.content}) |
311 | 313 | elif isinstance(part, UserPromptPart):
|
312 | 314 | bedrock_messages.extend(await self._map_user_prompt(part))
|
313 | 315 | elif isinstance(part, ToolReturnPart):
|
|
0 commit comments