11import os
22import json
3- from typing import Any , Optional
3+ from typing import Any , Optional , Literal
44
55try :
66 import boto3
2323def _convert_kwargs (kwargs : dict [str , Any ]) -> dict [str , Any ]:
2424 """Format the kwargs for AWS Bedrock."""
2525 kwargs = kwargs .copy ()
26-
26+
2727 # Convert tools and remove from kwargs
2828 tool_config = _convert_tool_spec (kwargs )
2929 kwargs .pop ("tools" , None ) # Remove tools from kwargs if present
30-
30+
3131 # Prepare inference config
32- inference_config = {
33- key : kwargs [key ]
34- for key in INFERENCE_PARAMETERS
35- if key in kwargs
36- }
37-
38- additional_fields = {
39- key : value
40- for key , value in kwargs .items ()
41- if key not in INFERENCE_PARAMETERS
42- }
43-
32+ inference_config = {key : kwargs [key ] for key in INFERENCE_PARAMETERS if key in kwargs }
33+
34+ additional_fields = {key : value for key , value in kwargs .items () if key not in INFERENCE_PARAMETERS }
35+
4436 request_config = {
4537 "inferenceConfig" : inference_config ,
4638 "additionalModelRequestFields" : additional_fields ,
4739 }
48-
40+
4941 if tool_config is not None :
5042 request_config ["toolConfig" ] = tool_config
51-
43+
5244 return request_config
5345
5446
5547def _convert_tool_spec (kwargs : dict [str , Any ]) -> Optional [dict [str , Any ]]:
5648 """Convert tool specifications to Bedrock format."""
5749 if "tools" not in kwargs :
5850 return None
59-
51+
6052 tool_config = {
6153 "tools" : [
6254 {
@@ -79,13 +71,13 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[dict[str, An
7971 if messages and messages [0 ]["role" ] == "system" :
8072 system_message = [{"text" : messages [0 ]["content" ]}]
8173 messages = messages [1 :]
82-
74+
8375 formatted_messages = []
8476 for message in messages :
8577 # Skip any additional system messages
8678 if message ["role" ] == "system" :
8779 continue
88-
80+
8981 if message ["role" ] == "tool" :
9082 bedrock_message = _convert_tool_result (message )
9183 if bedrock_message :
@@ -95,63 +87,65 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[dict[str, An
9587 if bedrock_message :
9688 formatted_messages .append (bedrock_message )
9789 else : # user messages
98- formatted_messages .append ({
99- "role" : message ["role" ],
100- "content" : [{"text" : message ["content" ]}],
101- })
102-
90+ formatted_messages .append (
91+ {
92+ "role" : message ["role" ],
93+ "content" : [{"text" : message ["content" ]}],
94+ }
95+ )
96+
10397 return system_message , formatted_messages
10498
10599
106100def _convert_tool_result (message : dict [str , Any ]) -> Optional [dict [str , Any ]]:
107101 """Convert OpenAI tool result format to AWS Bedrock format."""
108102 if message ["role" ] != "tool" or "content" not in message :
109103 return None
110-
104+
111105 tool_call_id = message .get ("tool_call_id" )
112106 if not tool_call_id :
113107 raise RuntimeError ("Tool result message must include tool_call_id" )
114-
108+
115109 try :
116110 content_json = json .loads (message ["content" ])
117111 content = [{"json" : content_json }]
118112 except json .JSONDecodeError :
119113 content = [{"text" : message ["content" ]}]
120-
114+
121115 return {
122116 "role" : "user" ,
123- "content" : [
124- {"toolResult" : {"toolUseId" : tool_call_id , "content" : content }}
125- ],
117+ "content" : [{"toolResult" : {"toolUseId" : tool_call_id , "content" : content }}],
126118 }
127119
128120
129121def _convert_assistant (message : dict [str , Any ]) -> Optional [dict [str , Any ]]:
130122 """Convert OpenAI assistant format to AWS Bedrock format."""
131123 if message ["role" ] != "assistant" :
132124 return None
133-
125+
134126 content = []
135-
127+
136128 if message .get ("content" ):
137129 content .append ({"text" : message ["content" ]})
138-
130+
139131 if message .get ("tool_calls" ):
140132 for tool_call in message ["tool_calls" ]:
141133 if tool_call ["type" ] == "function" :
142134 try :
143135 input_json = json .loads (tool_call ["function" ]["arguments" ])
144136 except json .JSONDecodeError :
145137 input_json = tool_call ["function" ]["arguments" ]
146-
147- content .append ({
148- "toolUse" : {
149- "toolUseId" : tool_call ["id" ],
150- "name" : tool_call ["function" ]["name" ],
151- "input" : input_json ,
138+
139+ content .append (
140+ {
141+ "toolUse" : {
142+ "toolUseId" : tool_call ["id" ],
143+ "name" : tool_call ["function" ]["name" ],
144+ "input" : input_json ,
145+ }
152146 }
153- } )
154-
147+ )
148+
155149 return {"role" : "assistant" , "content" : content } if content else None
156150
157151
@@ -173,20 +167,20 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
173167 ),
174168 )
175169 )
176-
170+
177171 if tool_calls :
178172 message = ChatCompletionMessage (
179173 content = None ,
180174 role = "assistant" ,
181175 tool_calls = tool_calls ,
182176 )
183-
177+
184178 choice = Choice (
185- finish_reason = "tool_calls" , # type: ignore
179+ finish_reason = "tool_calls" ,
186180 index = 0 ,
187181 message = message ,
188182 )
189-
183+
190184 usage = None
191185 if "usage" in response :
192186 usage_data = response ["usage" ]
@@ -195,7 +189,7 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
195189 prompt_tokens = usage_data .get ("inputTokens" , 0 ),
196190 total_tokens = usage_data .get ("totalTokens" , 0 ),
197191 )
198-
192+
199193 return ChatCompletion (
200194 id = response .get ("id" , "" ),
201195 model = response .get ("model" , "" ),
@@ -204,31 +198,32 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
204198 choices = [choice ],
205199 usage = usage ,
206200 )
207-
201+
208202 # Handle regular text response
209203 content = response ["output" ]["message" ]["content" ][0 ]["text" ]
210-
204+
211205 # Map Bedrock stopReason to OpenAI finish_reason
212206 stop_reason = response .get ("stopReason" )
207+ finish_reason : Literal ["stop" , "length" , "tool_calls" , "content_filter" , "function_call" ]
213208 if stop_reason == "complete" :
214209 finish_reason = "stop"
215210 elif stop_reason == "max_tokens" :
216211 finish_reason = "length"
217212 else :
218- finish_reason = stop_reason or "stop"
219-
213+ finish_reason = "stop"
214+
220215 message = ChatCompletionMessage (
221216 content = content ,
222217 role = "assistant" ,
223218 tool_calls = None ,
224219 )
225-
220+
226221 choice = Choice (
227- finish_reason = finish_reason , # type: ignore
222+ finish_reason = finish_reason ,
228223 index = 0 ,
229224 message = message ,
230225 )
231-
226+
232227 usage = None
233228 if "usage" in response :
234229 usage_data = response ["usage" ]
@@ -237,7 +232,7 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
237232 prompt_tokens = usage_data .get ("inputTokens" , 0 ),
238233 total_tokens = usage_data .get ("totalTokens" , 0 ),
239234 )
240-
235+
241236 return ChatCompletion (
242237 id = response .get ("id" , "" ),
243238 model = response .get ("model" , "" ),
@@ -255,12 +250,36 @@ def __init__(self, config: ApiConfig) -> None:
255250 """Initialize AWS Bedrock provider."""
256251 # AWS uses region from environment variables or default
257252 self .region_name = os .getenv ("AWS_REGION" , "us-east-1" )
258-
253+
259254 # Store config for later use
260255 self .config = config
261-
256+
262257 # Don't create client during init to avoid test failures
263- self .client = None
258+ self .client : Optional [Any ] = None
259+
260+ def _check_aws_credentials (self ) -> None :
261+ """Check if AWS credentials are available."""
262+ try :
263+ # Create a session to check if credentials are available
264+ session = boto3 .Session () # type: ignore[no-untyped-call, attr-defined]
265+ credentials = session .get_credentials () # type: ignore[no-untyped-call]
266+
267+ if credentials is None :
268+ raise MissingApiKeyError (
269+ provider_name = "AWS" , env_var_name = "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
270+ )
271+
272+ # Try to get the credentials to ensure they're not expired/invalid
273+ _ = credentials .access_key
274+ _ = credentials .secret_key
275+
276+ except Exception as e :
277+ if isinstance (e , MissingApiKeyError ):
278+ raise
279+ # If any other error occurs while checking credentials, treat as missing
280+ raise MissingApiKeyError (
281+ provider_name = "AWS" , env_var_name = "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
282+ ) from e
264283
265284 def completion (
266285 self ,
@@ -269,31 +288,35 @@ def completion(
269288 ** kwargs : Any ,
270289 ) -> ChatCompletion :
271290 """Create a chat completion using AWS Bedrock."""
291+ # Check credentials before creating client
292+ self ._check_aws_credentials ()
293+
272294 # Create client if not already created
273295 if self .client is None :
274296 try :
275- self .client = boto3 .client ("bedrock-runtime" , region_name = self .region_name )
297+ self .client = boto3 .client ("bedrock-runtime" , region_name = self .region_name ) # type: ignore[no-untyped-call]
276298 except Exception as e :
277299 raise RuntimeError (f"Failed to create AWS Bedrock client: { e } " ) from e
278-
300+
279301 system_message , formatted_messages = _convert_messages (messages )
280302 request_config = _convert_kwargs (kwargs )
281-
303+
282304 try :
305+ assert self .client is not None # For mypy
283306 response = self .client .converse (
284307 modelId = model ,
285308 messages = formatted_messages ,
286309 system = system_message ,
287310 ** request_config ,
288311 )
289-
312+
290313 # Convert to OpenAI format
291314 return _convert_response (response )
292-
315+
293316 except botocore .exceptions .ClientError as e :
294317 if e .response ["Error" ]["Code" ] == "ValidationException" :
295318 error_message = e .response ["Error" ]["Message" ]
296319 raise RuntimeError (f"AWS Bedrock validation error: { error_message } " ) from e
297320 raise RuntimeError (f"AWS Bedrock API error: { e } " ) from e
298321 except Exception as e :
299- raise RuntimeError (f"AWS Bedrock API error: { e } " ) from e
322+ raise RuntimeError (f"AWS Bedrock API error: { e } " ) from e
0 commit comments