1+ import os
2+ import json
3+ from typing import Any , Optional
4+
5+ try :
6+ import boto3
7+ import botocore
8+ except ImportError :
9+ msg = "boto3 is not installed. Please install it with `pip install any-llm-sdk[aws]`"
10+ raise ImportError (msg )
11+
12+ from openai .types .chat .chat_completion import ChatCompletion , Choice
13+ from openai .types .completion_usage import CompletionUsage
14+ from openai .types .chat .chat_completion_message import ChatCompletionMessage
15+ from openai .types .chat .chat_completion_message_tool_call import ChatCompletionMessageToolCall , Function
16+ from any_llm .provider import Provider , ApiConfig
17+ from any_llm .exceptions import MissingApiKeyError
18+
19+
# Converse API parameters that belong in `inferenceConfig`; every other
# kwarg is forwarded via `additionalModelRequestFields` (see _convert_kwargs).
INFERENCE_PARAMETERS = ["maxTokens", "temperature", "topP", "stopSequences"]
21+
22+
def _convert_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Format the kwargs for AWS Bedrock.

    Splits OpenAI-style kwargs into the Converse API's ``inferenceConfig``
    (known sampling parameters), the catch-all
    ``additionalModelRequestFields``, and an optional ``toolConfig``.
    """
    remaining = dict(kwargs)

    # Translate tools first, then drop them so they never leak into the
    # additional-fields bucket below.
    tool_config = _convert_tool_spec(remaining)
    remaining.pop("tools", None)

    # Known inference parameters go into inferenceConfig, preserving the
    # canonical INFERENCE_PARAMETERS ordering.
    inference_config: dict[str, Any] = {}
    for param in INFERENCE_PARAMETERS:
        if param in remaining:
            inference_config[param] = remaining[param]

    # Everything else is passed through as model-specific request fields.
    additional_fields: dict[str, Any] = {}
    for key, value in remaining.items():
        if key not in INFERENCE_PARAMETERS:
            additional_fields[key] = value

    request_config: dict[str, Any] = {
        "inferenceConfig": inference_config,
        "additionalModelRequestFields": additional_fields,
    }
    if tool_config is not None:
        request_config["toolConfig"] = tool_config
    return request_config
53+
54+
55+ def _convert_tool_spec (kwargs : dict [str , Any ]) -> Optional [dict [str , Any ]]:
56+ """Convert tool specifications to Bedrock format."""
57+ if "tools" not in kwargs :
58+ return None
59+
60+ tool_config = {
61+ "tools" : [
62+ {
63+ "toolSpec" : {
64+ "name" : tool ["function" ]["name" ],
65+ "description" : tool ["function" ].get ("description" , " " ),
66+ "inputSchema" : {"json" : tool ["function" ]["parameters" ]},
67+ }
68+ }
69+ for tool in kwargs ["tools" ]
70+ ]
71+ }
72+ return tool_config
73+
74+
75+ def _convert_messages (messages : list [dict [str , Any ]]) -> tuple [list [dict [str , Any ]], list [dict [str , Any ]]]:
76+ """Convert messages to AWS Bedrock format."""
77+ # Handle system message
78+ system_message = []
79+ if messages and messages [0 ]["role" ] == "system" :
80+ system_message = [{"text" : messages [0 ]["content" ]}]
81+ messages = messages [1 :]
82+
83+ formatted_messages = []
84+ for message in messages :
85+ # Skip any additional system messages
86+ if message ["role" ] == "system" :
87+ continue
88+
89+ if message ["role" ] == "tool" :
90+ bedrock_message = _convert_tool_result (message )
91+ if bedrock_message :
92+ formatted_messages .append (bedrock_message )
93+ elif message ["role" ] == "assistant" :
94+ bedrock_message = _convert_assistant (message )
95+ if bedrock_message :
96+ formatted_messages .append (bedrock_message )
97+ else : # user messages
98+ formatted_messages .append ({
99+ "role" : message ["role" ],
100+ "content" : [{"text" : message ["content" ]}],
101+ })
102+
103+ return system_message , formatted_messages
104+
105+
106+ def _convert_tool_result (message : dict [str , Any ]) -> Optional [dict [str , Any ]]:
107+ """Convert OpenAI tool result format to AWS Bedrock format."""
108+ if message ["role" ] != "tool" or "content" not in message :
109+ return None
110+
111+ tool_call_id = message .get ("tool_call_id" )
112+ if not tool_call_id :
113+ raise RuntimeError ("Tool result message must include tool_call_id" )
114+
115+ try :
116+ content_json = json .loads (message ["content" ])
117+ content = [{"json" : content_json }]
118+ except json .JSONDecodeError :
119+ content = [{"text" : message ["content" ]}]
120+
121+ return {
122+ "role" : "user" ,
123+ "content" : [
124+ {"toolResult" : {"toolUseId" : tool_call_id , "content" : content }}
125+ ],
126+ }
127+
128+
129+ def _convert_assistant (message : dict [str , Any ]) -> Optional [dict [str , Any ]]:
130+ """Convert OpenAI assistant format to AWS Bedrock format."""
131+ if message ["role" ] != "assistant" :
132+ return None
133+
134+ content = []
135+
136+ if message .get ("content" ):
137+ content .append ({"text" : message ["content" ]})
138+
139+ if message .get ("tool_calls" ):
140+ for tool_call in message ["tool_calls" ]:
141+ if tool_call ["type" ] == "function" :
142+ try :
143+ input_json = json .loads (tool_call ["function" ]["arguments" ])
144+ except json .JSONDecodeError :
145+ input_json = tool_call ["function" ]["arguments" ]
146+
147+ content .append ({
148+ "toolUse" : {
149+ "toolUseId" : tool_call ["id" ],
150+ "name" : tool_call ["function" ]["name" ],
151+ "input" : input_json ,
152+ }
153+ })
154+
155+ return {"role" : "assistant" , "content" : content } if content else None
156+
157+
# Maps Bedrock Converse stopReason values onto OpenAI finish_reason literals.
# The OpenAI Choice type accepts only a closed Literal set, so any unknown
# stop reason falls back to "stop" instead of being passed through raw.
_BEDROCK_FINISH_REASONS = {
    "end_turn": "stop",
    "complete": "stop",
    "stop_sequence": "stop",
    "max_tokens": "length",
    "content_filtered": "content_filter",
    "guardrail_intervened": "content_filter",
}


def _extract_usage(response: dict[str, Any]) -> Optional[CompletionUsage]:
    """Return token usage from a Bedrock response, or None when absent."""
    if "usage" not in response:
        return None
    usage_data = response["usage"]
    return CompletionUsage(
        completion_tokens=usage_data.get("outputTokens", 0),
        prompt_tokens=usage_data.get("inputTokens", 0),
        total_tokens=usage_data.get("totalTokens", 0),
    )


def _build_completion(response: dict[str, Any], choice: Choice) -> ChatCompletion:
    """Wrap a single Choice in the OpenAI ChatCompletion envelope."""
    return ChatCompletion(
        id=response.get("id", ""),
        model=response.get("model", ""),
        object="chat.completion",
        created=response.get("created", 0),
        choices=[choice],
        usage=_extract_usage(response),
    )


def _convert_response(response: dict[str, Any]) -> ChatCompletion:
    """Convert an AWS Bedrock Converse response to OpenAI ChatCompletion format.

    Args:
        response: Raw dict returned by the bedrock-runtime ``converse`` call.

    Returns:
        A ChatCompletion carrying either tool calls (stopReason "tool_use")
        or plain text content, plus token usage when Bedrock reports it.
    """
    # Tool-use turn: surface each toolUse content block as an OpenAI tool call.
    if response.get("stopReason") == "tool_use":
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=block["toolUse"]["toolUseId"],
                type="function",
                function=Function(
                    name=block["toolUse"]["name"],
                    # Bedrock supplies the input as parsed JSON; OpenAI
                    # represents tool arguments as a JSON string.
                    arguments=json.dumps(block["toolUse"]["input"]),
                ),
            )
            for block in response["output"]["message"]["content"]
            if "toolUse" in block
        ]
        if tool_calls:
            message = ChatCompletionMessage(
                content=None,
                role="assistant",
                tool_calls=tool_calls,
            )
            choice = Choice(
                finish_reason="tool_calls",
                index=0,
                message=message,
            )
            return _build_completion(response, choice)
        # No toolUse blocks despite the stop reason: fall through and treat
        # the response as plain text (matches previous behavior).

    # Handle regular text response.
    content = response["output"]["message"]["content"][0]["text"]

    # BUGFIX: the previous mapping only knew "complete" and "max_tokens" and
    # passed any other stopReason (notably Bedrock's actual "end_turn",
    # "stop_sequence", "guardrail_intervened", "content_filtered") straight
    # into Choice.finish_reason, which fails validation against the OpenAI
    # Literal type at runtime.
    stop_reason = response.get("stopReason")
    finish_reason = _BEDROCK_FINISH_REASONS.get(stop_reason, "stop")

    message = ChatCompletionMessage(
        content=content,
        role="assistant",
        tool_calls=None,
    )
    choice = Choice(
        finish_reason=finish_reason,  # type: ignore
        index=0,
        message=message,
    )
    return _build_completion(response, choice)
249+
250+
class AwsProvider(Provider):
    """AWS Bedrock Provider using boto3.

    Credentials are resolved by boto3's default provider chain; only the
    region is read here. NOTE(review): ``config`` is stored but never read,
    and ``MissingApiKeyError`` is imported at module level but unused —
    confirm whether API-key validation was intended here.
    """

    def __init__(self, config: ApiConfig) -> None:
        """Initialize AWS Bedrock provider."""
        # AWS uses region from environment variables or default (us-east-1).
        self.region_name = os.getenv("AWS_REGION", "us-east-1")

        # Store config for later use
        self.config = config

        # Don't create client during init to avoid test failures;
        # the client is created lazily on first completion() call.
        self.client = None

    def completion(
        self,
        model: str,
        messages: list[dict[str, Any]],
        **kwargs: Any,
    ) -> ChatCompletion:
        """Create a chat completion using AWS Bedrock.

        Args:
            model: Bedrock model id passed as ``modelId`` to the Converse API.
            messages: OpenAI-style chat messages; converted to Bedrock format.
            **kwargs: Inference parameters (maxTokens, temperature, topP,
                stopSequences), tools, and any extra model request fields.

        Returns:
            The Bedrock response converted to an OpenAI ChatCompletion.

        Raises:
            RuntimeError: if the client cannot be created or the API call
                fails (validation errors carry the service-provided message).
        """
        # Create client if not already created (lazy init, see __init__).
        if self.client is None:
            try:
                self.client = boto3.client("bedrock-runtime", region_name=self.region_name)
            except Exception as e:
                raise RuntimeError(f"Failed to create AWS Bedrock client: {e}") from e

        system_message, formatted_messages = _convert_messages(messages)
        request_config = _convert_kwargs(kwargs)

        try:
            response = self.client.converse(
                modelId=model,
                messages=formatted_messages,
                system=system_message,
                **request_config,
            )

            # Convert to OpenAI format
            return _convert_response(response)

        except botocore.exceptions.ClientError as e:
            # Surface validation errors (bad params/messages) with the
            # service-provided message; wrap other client errors generically.
            if e.response["Error"]["Code"] == "ValidationException":
                error_message = e.response["Error"]["Message"]
                raise RuntimeError(f"AWS Bedrock validation error: {error_message}") from e
            raise RuntimeError(f"AWS Bedrock API error: {e}") from e
        except Exception as e:
            # Catch-all boundary: conversion errors and non-ClientError boto
            # failures are re-raised as RuntimeError for a uniform caller API.
            raise RuntimeError(f"AWS Bedrock API error: {e}") from e