Skip to content

Commit a619ebd

Browse files
feature:AWS Bedrock Initial Commit (#99)
1 parent 0e8fb80 commit a619ebd

File tree

7 files changed

+435
-9
lines changed

7 files changed

+435
-9
lines changed

api/README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ OPENAI_API_KEY=your_openai_api_key # Required for embeddings and OpenAI m
3131
# Optional API Keys
3232
OPENROUTER_API_KEY=your_openrouter_api_key # Required only if using OpenRouter models
3333
34+
# AWS Bedrock Configuration
35+
AWS_ACCESS_KEY_ID=your_aws_access_key_id # Required for AWS Bedrock models
36+
AWS_SECRET_ACCESS_KEY=your_aws_secret_key # Required for AWS Bedrock models
37+
AWS_REGION=us-east-1 # Optional, defaults to us-east-1
38+
AWS_ROLE_ARN=your_aws_role_arn # Optional, for role-based authentication
39+
3440
# OpenAI API Configuration
3541
OPENAI_BASE_URL=https://custom-api-endpoint.com/v1 # Optional, for custom OpenAI API endpoints
3642
@@ -47,6 +53,7 @@ If you're not using Ollama mode, you need to configure an OpenAI API key for emb
4753
> - Get a Google API key from [Google AI Studio](https://makersuite.google.com/app/apikey)
4854
> - Get an OpenAI API key from [OpenAI Platform](https://platform.openai.com/api-keys)
4955
> - Get an OpenRouter API key from [OpenRouter](https://openrouter.ai/keys)
56+
> - Get AWS credentials from [AWS IAM Console](https://console.aws.amazon.com/iam/)
5057
5158
#### Advanced Environment Configuration
5259

@@ -56,6 +63,7 @@ DeepWiki supports multiple LLM providers. The environment variables above are re
5663
- **Google Gemini**: Requires `GOOGLE_API_KEY`
5764
- **OpenAI**: Requires `OPENAI_API_KEY`
5865
- **OpenRouter**: Requires `OPENROUTER_API_KEY`
66+
- **AWS Bedrock**: Requires `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` (optionally `AWS_REGION` and `AWS_ROLE_ARN`)
5967
- **Ollama**: No API key required (runs locally)
6068

6169
##### Custom OpenAI API Endpoints
@@ -75,7 +83,7 @@ DeepWiki now uses JSON configuration files to manage various system components i
7583

7684
1. **`generator.json`**: Configuration for text generation models
7785
- Located in `api/config/` by default
78-
- Defines available model providers (Google, OpenAI, OpenRouter, Ollama)
86+
- Defines available model providers (Google, OpenAI, OpenRouter, AWS Bedrock, Ollama)
7987
- Specifies default and available models for each provider
8088
- Contains model-specific parameters like temperature and top_p
8189

api/bedrock_client.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
"""AWS Bedrock ModelClient integration."""
2+
3+
import os
4+
import json
5+
import logging
6+
import boto3
7+
import botocore
8+
import backoff
9+
from typing import Dict, Any, Optional, List, Generator, Union, AsyncGenerator
10+
11+
from adalflow.core.model_client import ModelClient
12+
from adalflow.core.types import ModelType, GeneratorOutput
13+
14+
log = logging.getLogger(__name__)
15+
16+
class BedrockClient(ModelClient):
    __doc__ = r"""A component wrapper for the AWS Bedrock API client.

    AWS Bedrock provides a unified API that gives access to various foundation models
    including Amazon's own models and third-party models like Anthropic Claude.

    Example:

    ```python
    from api.bedrock_client import BedrockClient

    client = BedrockClient()
    generator = adal.Generator(
        model_client=client,
        model_kwargs={"model": "anthropic.claude-3-sonnet-20240229-v1:0"}
    )
    ```
    """

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: Optional[str] = None,
        aws_role_arn: Optional[str] = None,
        *args,
        **kwargs
    ) -> None:
        """Initialize the AWS Bedrock client.

        Args:
            aws_access_key_id: AWS access key ID. If not provided, will use environment variable AWS_ACCESS_KEY_ID.
            aws_secret_access_key: AWS secret access key. If not provided, will use environment variable AWS_SECRET_ACCESS_KEY.
            aws_region: AWS region. If not provided, will use environment variable AWS_REGION (defaults to "us-east-1").
            aws_role_arn: AWS IAM role ARN for role-based authentication. If not provided, will use environment variable AWS_ROLE_ARN.
        """
        super().__init__(*args, **kwargs)
        self.aws_access_key_id = aws_access_key_id or os.environ.get("AWS_ACCESS_KEY_ID")
        self.aws_secret_access_key = aws_secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY")
        self.aws_region = aws_region or os.environ.get("AWS_REGION", "us-east-1")
        self.aws_role_arn = aws_role_arn or os.environ.get("AWS_ROLE_ARN")

        self.sync_client = self.init_sync_client()
        # boto3 has no native async client; created lazily via init_async_client().
        self.async_client = None

    def init_sync_client(self):
        """Initialize the synchronous AWS Bedrock runtime client.

        Returns:
            A boto3 ``bedrock-runtime`` client, or ``None`` when initialization
            fails (e.g. bad credentials) — callers must check for ``None``.
        """
        try:
            # Create a session with the provided credentials.
            session = boto3.Session(
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
                region_name=self.aws_region
            )

            # If a role ARN is provided, assume that role via STS and build a
            # new session from the temporary credentials.
            if self.aws_role_arn:
                sts_client = session.client('sts')
                assumed_role = sts_client.assume_role(
                    RoleArn=self.aws_role_arn,
                    RoleSessionName="DeepWikiBedrockSession"
                )
                credentials = assumed_role['Credentials']

                session = boto3.Session(
                    aws_access_key_id=credentials['AccessKeyId'],
                    aws_secret_access_key=credentials['SecretAccessKey'],
                    aws_session_token=credentials['SessionToken'],
                    region_name=self.aws_region
                )

            # Create the Bedrock runtime client used for invoke_model calls.
            bedrock_runtime = session.client(
                service_name='bedrock-runtime',
                region_name=self.aws_region
            )

            return bedrock_runtime

        except Exception as e:
            log.error(f"Error initializing AWS Bedrock client: {str(e)}")
            # Return None to indicate initialization failure; call() reports
            # a readable error message rather than raising at import time.
            return None

    def init_async_client(self):
        """Initialize the asynchronous AWS Bedrock client.

        Note: boto3 doesn't have native async support, so we reuse the sync
        client; acall() pushes the blocking work onto a thread instead.
        """
        return self.sync_client

    def _get_model_provider(self, model_id: str) -> str:
        """Extract the provider from the model ID.

        Args:
            model_id: The model ID, e.g., "anthropic.claude-3-sonnet-20240229-v1:0"

        Returns:
            The provider name, e.g., "anthropic". Falls back to "amazon" when
            the ID has no "provider." prefix.
        """
        if "." in model_id:
            return model_id.split(".")[0]
        return "amazon"  # Default provider

    def _format_prompt_for_provider(
        self, provider: str, prompt: str, messages=None, max_tokens: int = 4096
    ) -> Dict[str, Any]:
        """Format the prompt according to the provider's request schema.

        Args:
            provider: The provider name, e.g., "anthropic"
            prompt: The prompt text
            messages: Optional list of ``{"role": ..., "content": ...}`` dicts
                for chat models (currently used by the Anthropic path only).
            max_tokens: Generation budget; previously hard-coded, now a
                backward-compatible parameter (default 4096).

        Returns:
            A dictionary with the provider-specific request body.
        """
        if provider == "anthropic":
            # Claude on Bedrock uses the Messages API schema.
            if messages:
                formatted_messages = [
                    {
                        # Anything that is not explicitly a user turn is
                        # treated as an assistant turn.
                        "role": "user" if msg.get("role") == "user" else "assistant",
                        "content": [{"type": "text", "text": msg.get("content", "")}],
                    }
                    for msg in messages
                ]
            else:
                # Single-shot prompt wrapped as one user message.
                formatted_messages = [
                    {"role": "user", "content": [{"type": "text", "text": prompt}]}
                ]
            return {
                "anthropic_version": "bedrock-2023-05-31",
                "messages": formatted_messages,
                "max_tokens": max_tokens,
            }
        elif provider == "amazon":
            # Amazon Titan text models.
            return {
                "inputText": prompt,
                "textGenerationConfig": {
                    "maxTokenCount": max_tokens,
                    "stopSequences": [],
                    "temperature": 0.7,
                    "topP": 0.8,
                },
            }
        elif provider == "cohere":
            return {
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": 0.7,
                "p": 0.8,
            }
        elif provider == "ai21":
            return {
                "prompt": prompt,
                "maxTokens": max_tokens,
                "temperature": 0.7,
                "topP": 0.8,
            }
        else:
            # Default format for unknown providers.
            return {"prompt": prompt}

    def _extract_response_text(self, provider: str, response: Dict[str, Any]) -> str:
        """Extract the generated text from a provider-specific response body.

        Args:
            provider: The provider name, e.g., "anthropic"
            response: The parsed JSON response from the Bedrock API

        Returns:
            The generated text, or "" when the expected field is missing or
            empty (previously an empty list raised IndexError).
        """
        if provider == "anthropic":
            blocks = response.get("content") or [{}]
            return blocks[0].get("text", "")
        elif provider == "amazon":
            results = response.get("results") or [{}]
            return results[0].get("outputText", "")
        elif provider == "cohere":
            generations = response.get("generations") or [{}]
            return generations[0].get("text", "")
        elif provider == "ai21":
            completions = response.get("completions") or [{}]
            return completions[0].get("data", {}).get("text", "")
        else:
            # Best-effort extraction for unknown providers.
            if isinstance(response, dict):
                for key in ["text", "content", "output", "completion"]:
                    if key in response:
                        return response[key]
            return str(response)

    @backoff.on_exception(
        backoff.expo,
        (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError),
        max_time=5,
    )
    def call(self, api_kwargs: Dict = None, model_type: ModelType = None) -> Any:
        """Make a synchronous call to the AWS Bedrock API.

        Args:
            api_kwargs: Keys: "model" (Bedrock model ID), "input" (prompt text),
                optional "messages", "temperature", "top_p", "max_tokens".
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            The generated text, or an "Error: ..." string for non-AWS failures
            (kept for backward compatibility with existing callers).

        Raises:
            botocore.exceptions.ClientError / BotoCoreError: transient AWS
                errors are retried by the @backoff decorator and re-raised
                once max_time is exhausted.
            ValueError: for unsupported ``model_type`` values.
        """
        api_kwargs = api_kwargs or {}

        # Check if client is initialized.
        if not self.sync_client:
            error_msg = "AWS Bedrock client not initialized. Check your AWS credentials and region."
            log.error(error_msg)
            return error_msg

        if model_type == ModelType.LLM:
            model_id = api_kwargs.get("model", "anthropic.claude-3-sonnet-20240229-v1:0")
            provider = self._get_model_provider(model_id)

            # Get the prompt from api_kwargs.
            prompt = api_kwargs.get("input", "")
            messages = api_kwargs.get("messages")

            # Format the prompt according to the provider; max_tokens is now
            # overridable per call instead of hard-coded.
            max_tokens = api_kwargs.get("max_tokens", 4096)
            request_body = self._format_prompt_for_provider(provider, prompt, messages, max_tokens)

            # Override sampling parameters if provided. Only providers with a
            # known schema are touched; unknown providers keep the default body.
            if "temperature" in api_kwargs:
                if provider == "amazon":
                    request_body["textGenerationConfig"]["temperature"] = api_kwargs["temperature"]
                elif provider in ("anthropic", "cohere", "ai21"):
                    request_body["temperature"] = api_kwargs["temperature"]

            if "top_p" in api_kwargs:
                top_p = api_kwargs["top_p"]
                if provider == "anthropic":
                    request_body["top_p"] = top_p
                elif provider == "amazon":
                    request_body["textGenerationConfig"]["topP"] = top_p
                elif provider == "cohere":
                    request_body["p"] = top_p
                elif provider == "ai21":
                    request_body["topP"] = top_p

            # Convert request body to JSON.
            body = json.dumps(request_body)

            try:
                # Make the API call.
                response = self.sync_client.invoke_model(
                    modelId=model_id,
                    body=body
                )

                # Parse the streaming body and extract the generated text.
                response_body = json.loads(response["body"].read())
                return self._extract_response_text(provider, response_body)

            except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError):
                # Re-raise AWS errors so the @backoff decorator can retry them.
                # (Swallowing them here would make the retry decorator a no-op.)
                raise
            except Exception as e:
                log.error(f"Error calling AWS Bedrock API: {str(e)}")
                return f"Error: {str(e)}"
        else:
            raise ValueError(f"Model type {model_type} is not supported by AWS Bedrock client")

    async def acall(self, api_kwargs: Dict = None, model_type: ModelType = None) -> Any:
        """Make an asynchronous call to the AWS Bedrock API.

        boto3 has no native async client, so the blocking call is delegated
        to a worker thread to avoid stalling the event loop.
        """
        import asyncio  # Local import: only needed on the async path.

        return await asyncio.to_thread(self.call, api_kwargs, model_type)

    def convert_inputs_to_api_kwargs(
        self, input: Any = None, model_kwargs: Dict = None, model_type: ModelType = None
    ) -> Dict:
        """Convert adalflow inputs to API kwargs for AWS Bedrock.

        Args:
            input: The prompt text.
            model_kwargs: May contain "model", "temperature", "top_p",
                "max_tokens".
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            A dict suitable for :meth:`call` / :meth:`acall`.

        Raises:
            ValueError: for unsupported ``model_type`` values.
        """
        model_kwargs = model_kwargs or {}

        if model_type == ModelType.LLM:
            api_kwargs = {
                "model": model_kwargs.get("model", "anthropic.claude-3-sonnet-20240229-v1:0"),
                "input": input,
            }

            # Forward sampling parameters only when explicitly provided.
            for key in ("temperature", "top_p", "max_tokens"):
                if key in model_kwargs:
                    api_kwargs[key] = model_kwargs[key]

            return api_kwargs
        else:
            raise ValueError(f"Model type {model_type} is not supported by AWS Bedrock client")

0 commit comments

Comments
 (0)