Skip to content

Commit b8df34b

Browse files
committed
lint
1 parent 97f1f30 commit b8df34b

File tree

14 files changed

+213
-198
lines changed

14 files changed

+213
-198
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .aws import AwsProvider
22

3-
__all__ = ["AwsProvider"]
3+
__all__ = ["AwsProvider"]

src/any_llm/providers/aws/aws.py

Lines changed: 87 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import json
3-
from typing import Any, Optional
3+
from typing import Any, Optional, Literal
44

55
try:
66
import boto3
@@ -23,40 +23,32 @@
2323
def _convert_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
2424
"""Format the kwargs for AWS Bedrock."""
2525
kwargs = kwargs.copy()
26-
26+
2727
# Convert tools and remove from kwargs
2828
tool_config = _convert_tool_spec(kwargs)
2929
kwargs.pop("tools", None) # Remove tools from kwargs if present
30-
30+
3131
# Prepare inference config
32-
inference_config = {
33-
key: kwargs[key]
34-
for key in INFERENCE_PARAMETERS
35-
if key in kwargs
36-
}
37-
38-
additional_fields = {
39-
key: value
40-
for key, value in kwargs.items()
41-
if key not in INFERENCE_PARAMETERS
42-
}
43-
32+
inference_config = {key: kwargs[key] for key in INFERENCE_PARAMETERS if key in kwargs}
33+
34+
additional_fields = {key: value for key, value in kwargs.items() if key not in INFERENCE_PARAMETERS}
35+
4436
request_config = {
4537
"inferenceConfig": inference_config,
4638
"additionalModelRequestFields": additional_fields,
4739
}
48-
40+
4941
if tool_config is not None:
5042
request_config["toolConfig"] = tool_config
51-
43+
5244
return request_config
5345

5446

5547
def _convert_tool_spec(kwargs: dict[str, Any]) -> Optional[dict[str, Any]]:
5648
"""Convert tool specifications to Bedrock format."""
5749
if "tools" not in kwargs:
5850
return None
59-
51+
6052
tool_config = {
6153
"tools": [
6254
{
@@ -79,13 +71,13 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[dict[str, An
7971
if messages and messages[0]["role"] == "system":
8072
system_message = [{"text": messages[0]["content"]}]
8173
messages = messages[1:]
82-
74+
8375
formatted_messages = []
8476
for message in messages:
8577
# Skip any additional system messages
8678
if message["role"] == "system":
8779
continue
88-
80+
8981
if message["role"] == "tool":
9082
bedrock_message = _convert_tool_result(message)
9183
if bedrock_message:
@@ -95,63 +87,65 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[dict[str, An
9587
if bedrock_message:
9688
formatted_messages.append(bedrock_message)
9789
else: # user messages
98-
formatted_messages.append({
99-
"role": message["role"],
100-
"content": [{"text": message["content"]}],
101-
})
102-
90+
formatted_messages.append(
91+
{
92+
"role": message["role"],
93+
"content": [{"text": message["content"]}],
94+
}
95+
)
96+
10397
return system_message, formatted_messages
10498

10599

106100
def _convert_tool_result(message: dict[str, Any]) -> Optional[dict[str, Any]]:
107101
"""Convert OpenAI tool result format to AWS Bedrock format."""
108102
if message["role"] != "tool" or "content" not in message:
109103
return None
110-
104+
111105
tool_call_id = message.get("tool_call_id")
112106
if not tool_call_id:
113107
raise RuntimeError("Tool result message must include tool_call_id")
114-
108+
115109
try:
116110
content_json = json.loads(message["content"])
117111
content = [{"json": content_json}]
118112
except json.JSONDecodeError:
119113
content = [{"text": message["content"]}]
120-
114+
121115
return {
122116
"role": "user",
123-
"content": [
124-
{"toolResult": {"toolUseId": tool_call_id, "content": content}}
125-
],
117+
"content": [{"toolResult": {"toolUseId": tool_call_id, "content": content}}],
126118
}
127119

128120

129121
def _convert_assistant(message: dict[str, Any]) -> Optional[dict[str, Any]]:
130122
"""Convert OpenAI assistant format to AWS Bedrock format."""
131123
if message["role"] != "assistant":
132124
return None
133-
125+
134126
content = []
135-
127+
136128
if message.get("content"):
137129
content.append({"text": message["content"]})
138-
130+
139131
if message.get("tool_calls"):
140132
for tool_call in message["tool_calls"]:
141133
if tool_call["type"] == "function":
142134
try:
143135
input_json = json.loads(tool_call["function"]["arguments"])
144136
except json.JSONDecodeError:
145137
input_json = tool_call["function"]["arguments"]
146-
147-
content.append({
148-
"toolUse": {
149-
"toolUseId": tool_call["id"],
150-
"name": tool_call["function"]["name"],
151-
"input": input_json,
138+
139+
content.append(
140+
{
141+
"toolUse": {
142+
"toolUseId": tool_call["id"],
143+
"name": tool_call["function"]["name"],
144+
"input": input_json,
145+
}
152146
}
153-
})
154-
147+
)
148+
155149
return {"role": "assistant", "content": content} if content else None
156150

157151

@@ -173,20 +167,20 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
173167
),
174168
)
175169
)
176-
170+
177171
if tool_calls:
178172
message = ChatCompletionMessage(
179173
content=None,
180174
role="assistant",
181175
tool_calls=tool_calls,
182176
)
183-
177+
184178
choice = Choice(
185-
finish_reason="tool_calls", # type: ignore
179+
finish_reason="tool_calls",
186180
index=0,
187181
message=message,
188182
)
189-
183+
190184
usage = None
191185
if "usage" in response:
192186
usage_data = response["usage"]
@@ -195,7 +189,7 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
195189
prompt_tokens=usage_data.get("inputTokens", 0),
196190
total_tokens=usage_data.get("totalTokens", 0),
197191
)
198-
192+
199193
return ChatCompletion(
200194
id=response.get("id", ""),
201195
model=response.get("model", ""),
@@ -204,31 +198,32 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
204198
choices=[choice],
205199
usage=usage,
206200
)
207-
201+
208202
# Handle regular text response
209203
content = response["output"]["message"]["content"][0]["text"]
210-
204+
211205
# Map Bedrock stopReason to OpenAI finish_reason
212206
stop_reason = response.get("stopReason")
207+
finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
213208
if stop_reason == "complete":
214209
finish_reason = "stop"
215210
elif stop_reason == "max_tokens":
216211
finish_reason = "length"
217212
else:
218-
finish_reason = stop_reason or "stop"
219-
213+
finish_reason = "stop"
214+
220215
message = ChatCompletionMessage(
221216
content=content,
222217
role="assistant",
223218
tool_calls=None,
224219
)
225-
220+
226221
choice = Choice(
227-
finish_reason=finish_reason, # type: ignore
222+
finish_reason=finish_reason,
228223
index=0,
229224
message=message,
230225
)
231-
226+
232227
usage = None
233228
if "usage" in response:
234229
usage_data = response["usage"]
@@ -237,7 +232,7 @@ def _convert_response(response: dict[str, Any]) -> ChatCompletion:
237232
prompt_tokens=usage_data.get("inputTokens", 0),
238233
total_tokens=usage_data.get("totalTokens", 0),
239234
)
240-
235+
241236
return ChatCompletion(
242237
id=response.get("id", ""),
243238
model=response.get("model", ""),
@@ -255,12 +250,36 @@ def __init__(self, config: ApiConfig) -> None:
255250
"""Initialize AWS Bedrock provider."""
256251
# AWS uses region from environment variables or default
257252
self.region_name = os.getenv("AWS_REGION", "us-east-1")
258-
253+
259254
# Store config for later use
260255
self.config = config
261-
256+
262257
# Don't create client during init to avoid test failures
263-
self.client = None
258+
self.client: Optional[Any] = None
259+
260+
def _check_aws_credentials(self) -> None:
261+
"""Check if AWS credentials are available."""
262+
try:
263+
# Create a session to check if credentials are available
264+
session = boto3.Session() # type: ignore[no-untyped-call, attr-defined]
265+
credentials = session.get_credentials() # type: ignore[no-untyped-call]
266+
267+
if credentials is None:
268+
raise MissingApiKeyError(
269+
provider_name="AWS", env_var_name="AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
270+
)
271+
272+
# Try to get the credentials to ensure they're not expired/invalid
273+
_ = credentials.access_key
274+
_ = credentials.secret_key
275+
276+
except Exception as e:
277+
if isinstance(e, MissingApiKeyError):
278+
raise
279+
# If any other error occurs while checking credentials, treat as missing
280+
raise MissingApiKeyError(
281+
provider_name="AWS", env_var_name="AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
282+
) from e
264283

265284
def completion(
266285
self,
@@ -269,31 +288,35 @@ def completion(
269288
**kwargs: Any,
270289
) -> ChatCompletion:
271290
"""Create a chat completion using AWS Bedrock."""
291+
# Check credentials before creating client
292+
self._check_aws_credentials()
293+
272294
# Create client if not already created
273295
if self.client is None:
274296
try:
275-
self.client = boto3.client("bedrock-runtime", region_name=self.region_name)
297+
self.client = boto3.client("bedrock-runtime", region_name=self.region_name) # type: ignore[no-untyped-call]
276298
except Exception as e:
277299
raise RuntimeError(f"Failed to create AWS Bedrock client: {e}") from e
278-
300+
279301
system_message, formatted_messages = _convert_messages(messages)
280302
request_config = _convert_kwargs(kwargs)
281-
303+
282304
try:
305+
assert self.client is not None # For mypy
283306
response = self.client.converse(
284307
modelId=model,
285308
messages=formatted_messages,
286309
system=system_message,
287310
**request_config,
288311
)
289-
312+
290313
# Convert to OpenAI format
291314
return _convert_response(response)
292-
315+
293316
except botocore.exceptions.ClientError as e:
294317
if e.response["Error"]["Code"] == "ValidationException":
295318
error_message = e.response["Error"]["Message"]
296319
raise RuntimeError(f"AWS Bedrock validation error: {error_message}") from e
297320
raise RuntimeError(f"AWS Bedrock API error: {e}") from e
298321
except Exception as e:
299-
raise RuntimeError(f"AWS Bedrock API error: {e}") from e
322+
raise RuntimeError(f"AWS Bedrock API error: {e}") from e
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .azure import AzureProvider
22

3-
__all__ = ["AzureProvider"]
3+
__all__ = ["AzureProvider"]

0 commit comments

Comments
 (0)