Skip to content

Commit 97f1f30

Browse files
committed
Additional providers support
1 parent c0578a0 commit 97f1f30

File tree

18 files changed

+1330
-1
lines changed

18 files changed

+1330
-1
lines changed

docs/images/any-llm_favicon.png

-43.7 KB
Loading

docs/providers.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
- [Ollama](https://github.com/ollama/ollama)
1010
- [DeepSeek](https://platform.deepseek.com/)
1111
- [HuggingFace](https://huggingface.co/inference-endpoints)
12+
- [Cohere](https://cohere.com/api)
13+
- [Cerebras](https://docs.cerebras.ai/)
14+
- [Fireworks](https://fireworks.ai/api)
15+
- [Groq](https://groq.com/api)
16+
- [AWS Bedrock](https://aws.amazon.com/bedrock/)
17+
- [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
18+
- [IBM Watsonx](https://www.ibm.com/watsonx)
1219
- [Inception Labs](https://inceptionlabs.ai/)
1320
- [Moonshot AI](https://platform.moonshot.ai/)
1421
- [Nebius AI Studio](https://studio.nebius.ai/)

pyproject.toml

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
[project.optional-dependencies]
1919

2020
all = [
21-
"any-llm-sdk[mistral,anthropic,huggingface,google]"
21+
"any-llm-sdk[mistral,anthropic,huggingface,google,cohere,cerebras,fireworks,groq,aws,azure,watsonx]"
2222
]
2323

2424
mistral = [
@@ -37,6 +37,34 @@ huggingface = [
3737
"huggingface-hub",
3838
]
3939

40+
cohere = [
41+
"cohere",
42+
]
43+
44+
cerebras = [
45+
"cerebras-cloud-sdk",
46+
]
47+
48+
fireworks = [
49+
"httpx",
50+
]
51+
52+
groq = [
53+
"groq",
54+
]
55+
56+
aws = [
57+
"boto3",
58+
]
59+
60+
azure = [
61+
"httpx",
62+
]
63+
64+
watsonx = [
65+
"ibm-watsonx-ai",
66+
]
67+
4068
[project.urls]
4169
Documentation = "https://mozilla-ai.github.io/any-llm/"
4270
Issues = "https://github.com/mozilla-ai/any-llm/issues"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .aws import AwsProvider
2+
3+
__all__ = ["AwsProvider"]

src/any_llm/providers/aws/aws.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import os
2+
import json
3+
from typing import Any, Optional
4+
5+
try:
6+
import boto3
7+
import botocore
8+
except ImportError:
9+
msg = "boto3 is not installed. Please install it with `pip install any-llm-sdk[aws]`"
10+
raise ImportError(msg)
11+
12+
from openai.types.chat.chat_completion import ChatCompletion, Choice
13+
from openai.types.completion_usage import CompletionUsage
14+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
15+
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
16+
from any_llm.provider import Provider, ApiConfig
17+
from any_llm.exceptions import MissingApiKeyError
18+
19+
20+
INFERENCE_PARAMETERS = ["maxTokens", "temperature", "topP", "stopSequences"]
21+
22+
23+
def _convert_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Translate completion kwargs into a Bedrock Converse request config.

    Keys listed in INFERENCE_PARAMETERS (already in Bedrock camelCase) are
    routed to ``inferenceConfig``; everything else is forwarded untouched via
    ``additionalModelRequestFields``. A ``tools`` entry, when present, is
    converted into a ``toolConfig`` block.
    """
    remaining = dict(kwargs)

    # Build the tool configuration first, then drop "tools" so it is not
    # forwarded as an additional model request field.
    tool_config = _convert_tool_spec(remaining)
    remaining.pop("tools", None)

    # Pull the known inference parameters out, preserving their canonical
    # order; whatever is left passes through as additional fields.
    inference_config: dict[str, Any] = {}
    for name in INFERENCE_PARAMETERS:
        if name in remaining:
            inference_config[name] = remaining.pop(name)

    request_config: dict[str, Any] = {
        "inferenceConfig": inference_config,
        "additionalModelRequestFields": remaining,
    }
    if tool_config is not None:
        request_config["toolConfig"] = tool_config

    return request_config
53+
54+
55+
def _convert_tool_spec(kwargs: dict[str, Any]) -> Optional[dict[str, Any]]:
56+
"""Convert tool specifications to Bedrock format."""
57+
if "tools" not in kwargs:
58+
return None
59+
60+
tool_config = {
61+
"tools": [
62+
{
63+
"toolSpec": {
64+
"name": tool["function"]["name"],
65+
"description": tool["function"].get("description", " "),
66+
"inputSchema": {"json": tool["function"]["parameters"]},
67+
}
68+
}
69+
for tool in kwargs["tools"]
70+
]
71+
}
72+
return tool_config
73+
74+
75+
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
76+
"""Convert messages to AWS Bedrock format."""
77+
# Handle system message
78+
system_message = []
79+
if messages and messages[0]["role"] == "system":
80+
system_message = [{"text": messages[0]["content"]}]
81+
messages = messages[1:]
82+
83+
formatted_messages = []
84+
for message in messages:
85+
# Skip any additional system messages
86+
if message["role"] == "system":
87+
continue
88+
89+
if message["role"] == "tool":
90+
bedrock_message = _convert_tool_result(message)
91+
if bedrock_message:
92+
formatted_messages.append(bedrock_message)
93+
elif message["role"] == "assistant":
94+
bedrock_message = _convert_assistant(message)
95+
if bedrock_message:
96+
formatted_messages.append(bedrock_message)
97+
else: # user messages
98+
formatted_messages.append({
99+
"role": message["role"],
100+
"content": [{"text": message["content"]}],
101+
})
102+
103+
return system_message, formatted_messages
104+
105+
106+
def _convert_tool_result(message: dict[str, Any]) -> Optional[dict[str, Any]]:
107+
"""Convert OpenAI tool result format to AWS Bedrock format."""
108+
if message["role"] != "tool" or "content" not in message:
109+
return None
110+
111+
tool_call_id = message.get("tool_call_id")
112+
if not tool_call_id:
113+
raise RuntimeError("Tool result message must include tool_call_id")
114+
115+
try:
116+
content_json = json.loads(message["content"])
117+
content = [{"json": content_json}]
118+
except json.JSONDecodeError:
119+
content = [{"text": message["content"]}]
120+
121+
return {
122+
"role": "user",
123+
"content": [
124+
{"toolResult": {"toolUseId": tool_call_id, "content": content}}
125+
],
126+
}
127+
128+
129+
def _convert_assistant(message: dict[str, Any]) -> Optional[dict[str, Any]]:
130+
"""Convert OpenAI assistant format to AWS Bedrock format."""
131+
if message["role"] != "assistant":
132+
return None
133+
134+
content = []
135+
136+
if message.get("content"):
137+
content.append({"text": message["content"]})
138+
139+
if message.get("tool_calls"):
140+
for tool_call in message["tool_calls"]:
141+
if tool_call["type"] == "function":
142+
try:
143+
input_json = json.loads(tool_call["function"]["arguments"])
144+
except json.JSONDecodeError:
145+
input_json = tool_call["function"]["arguments"]
146+
147+
content.append({
148+
"toolUse": {
149+
"toolUseId": tool_call["id"],
150+
"name": tool_call["function"]["name"],
151+
"input": input_json,
152+
}
153+
})
154+
155+
return {"role": "assistant", "content": content} if content else None
156+
157+
158+
def _extract_usage(response: dict[str, Any]) -> Optional[CompletionUsage]:
    """Return OpenAI-style usage data from a Bedrock response, if reported."""
    if "usage" not in response:
        return None
    usage_data = response["usage"]
    return CompletionUsage(
        completion_tokens=usage_data.get("outputTokens", 0),
        prompt_tokens=usage_data.get("inputTokens", 0),
        total_tokens=usage_data.get("totalTokens", 0),
    )


def _build_completion(response: dict[str, Any], choice: Choice) -> ChatCompletion:
    """Wrap a single Choice in a ChatCompletion envelope."""
    return ChatCompletion(
        id=response.get("id", ""),
        model=response.get("model", ""),
        object="chat.completion",
        created=response.get("created", 0),
        choices=[choice],
        usage=_extract_usage(response),
    )


# Bedrock Converse stopReason -> OpenAI finish_reason. The Converse API
# reports "end_turn" / "stop_sequence" for normal completions; previously
# those raw values were passed straight into Choice.finish_reason (a closed
# Literal), which fails model validation at runtime. Unknown values now fall
# back to "stop".
_FINISH_REASON_MAP = {
    "end_turn": "stop",
    "complete": "stop",
    "stop_sequence": "stop",
    "max_tokens": "length",
    "content_filtered": "content_filter",
    "guardrail_intervened": "content_filter",
}


def _convert_response(response: dict[str, Any]) -> ChatCompletion:
    """Convert an AWS Bedrock Converse response to OpenAI ChatCompletion format.

    Tool-use turns become assistant messages carrying ``tool_calls``; plain
    text turns become regular assistant messages. Token usage is included
    when the response reports it.
    """
    # Tool-use turn: collect every toolUse content block into a tool call.
    if response.get("stopReason") == "tool_use":
        tool_calls = [
            ChatCompletionMessageToolCall(
                id=block["toolUse"]["toolUseId"],
                type="function",
                function=Function(
                    name=block["toolUse"]["name"],
                    arguments=json.dumps(block["toolUse"]["input"]),
                ),
            )
            for block in response["output"]["message"]["content"]
            if "toolUse" in block
        ]

        if tool_calls:
            message = ChatCompletionMessage(
                content=None,
                role="assistant",
                tool_calls=tool_calls,
            )
            choice = Choice(finish_reason="tool_calls", index=0, message=message)
            return _build_completion(response, choice)
        # No toolUse blocks despite stopReason: fall through to text handling.

    # Regular text turn.
    content = response["output"]["message"]["content"][0]["text"]
    finish_reason = _FINISH_REASON_MAP.get(response.get("stopReason"), "stop")

    message = ChatCompletionMessage(content=content, role="assistant", tool_calls=None)
    choice = Choice(
        finish_reason=finish_reason,  # type: ignore
        index=0,
        message=message,
    )
    return _build_completion(response, choice)
249+
250+
251+
class AwsProvider(Provider):
    """AWS Bedrock Provider using boto3."""

    def __init__(self, config: ApiConfig) -> None:
        """Initialize AWS Bedrock provider.

        The boto3 client is created lazily on the first completion call, so
        constructing the provider never touches AWS. Credentials come from
        the standard AWS environment; region from AWS_REGION (default
        us-east-1).
        """
        self.region_name = os.getenv("AWS_REGION", "us-east-1")
        # Kept for parity with other providers; boto3 resolves credentials
        # from the environment rather than from this config.
        self.config = config
        # Deferred client creation (see docstring).
        self.client = None

    def completion(
        self,
        model: str,
        messages: list[dict[str, Any]],
        **kwargs: Any,
    ) -> ChatCompletion:
        """Create a chat completion using AWS Bedrock."""
        if self.client is None:
            try:
                self.client = boto3.client("bedrock-runtime", region_name=self.region_name)
            except Exception as e:
                raise RuntimeError(f"Failed to create AWS Bedrock client: {e}") from e

        system_message, formatted_messages = _convert_messages(messages)
        request_config = _convert_kwargs(kwargs)

        try:
            raw_response = self.client.converse(
                modelId=model,
                messages=formatted_messages,
                system=system_message,
                **request_config,
            )
            # Convert to OpenAI format before leaving the try block so
            # conversion failures are wrapped like any other API error.
            return _convert_response(raw_response)
        except botocore.exceptions.ClientError as e:
            # Surface validation problems (bad model id, malformed request)
            # with the service-provided message.
            if e.response["Error"]["Code"] == "ValidationException":
                error_message = e.response["Error"]["Message"]
                raise RuntimeError(f"AWS Bedrock validation error: {error_message}") from e
            raise RuntimeError(f"AWS Bedrock API error: {e}") from e
        except Exception as e:
            raise RuntimeError(f"AWS Bedrock API error: {e}") from e
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .azure import AzureProvider
2+
3+
__all__ = ["AzureProvider"]

0 commit comments

Comments
 (0)