Skip to content

Commit fa268d9

Browse files
authored
feat: Add AWS Bedrock LLM Support (#1173)
* feat: Add AWS Bedrock LLM Support This commit adds support for AWS Bedrock for LLM parsing. The implementation follows the approach of other LLM providers and uses the `BEDROCK_API_KEY` and `BEDROCK_REGION` environment variables for authentication. This resolves issue #1162. * feat: add bedrock llm extraction example * feat: address feedback on Bedrock LLM support PR - Update documentation to include Bedrock LLM integration. - Remove unnecessary test for LlmApiType.BEDROCK. - Add Bedrock to the existing manuals_llm_extraction example instead of creating a new one.
1 parent 6f91d97 commit fa268d9

File tree

5 files changed

+225
-1
lines changed

5 files changed

+225
-1
lines changed

docs/docs/ai/llm.mdx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ We support the following types of LLM APIs:
2828
| [LiteLLM](#litellm) | `LlmApiType.LITE_LLM` |||
2929
| [OpenRouter](#openrouter) | `LlmApiType.OPEN_ROUTER` |||
3030
| [vLLM](#vllm) | `LlmApiType.VLLM` |||
31+
| [Bedrock](#bedrock) | `LlmApiType.BEDROCK` |||
3132

3233
## LLM Tasks
3334

@@ -440,3 +441,28 @@ cocoindex.LlmSpec(
440441

441442
</TabItem>
442443
</Tabs>
444+
445+
### Bedrock
446+
447+
To use the Bedrock API, you need to set up credentials via the following environment variables:

- `BEDROCK_API_KEY` — a Bedrock API key, sent as the bearer token (required)
- `BEDROCK_REGION` — the AWS region to call (optional; defaults to `us-east-1`)
452+
453+
A spec for Bedrock looks like this:
454+
455+
<Tabs>
456+
<TabItem value="python" label="Python" default>
457+
458+
```python
459+
cocoindex.LlmSpec(
460+
api_type=cocoindex.LlmApiType.BEDROCK,
461+
model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
462+
)
463+
```
464+
465+
</TabItem>
466+
</Tabs>
467+
468+
You can find the full list of models supported by Bedrock [here](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).

examples/manuals_llm_extraction/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def manual_extraction_flow(
118118
# Replace by this spec below, to use Anthropic API model
119119
# llm_spec=cocoindex.LlmSpec(
120120
# api_type=cocoindex.LlmApiType.ANTHROPIC, model="claude-3-5-sonnet-latest"),
121+
# Replace by this spec below, to use Bedrock API model
122+
# llm_spec=cocoindex.LlmSpec(
123+
# api_type=cocoindex.LlmApiType.BEDROCK, model="us.anthropic.claude-3-5-haiku-20241022-v1:0"),
121124
output_type=ModuleInfo,
122125
instruction="Please extract Python module information from the manual.",
123126
)

python/cocoindex/llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class LlmApiType(Enum):
1414
OPEN_ROUTER = "OpenRouter"
1515
VOYAGE = "Voyage"
1616
VLLM = "Vllm"
17+
BEDROCK = "Bedrock"
1718

1819

1920
@dataclass

src/llm/bedrock.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
use crate::prelude::*;
2+
use base64::prelude::*;
3+
4+
use crate::llm::{
5+
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
6+
ToJsonSchemaOptions, detect_image_mime_type,
7+
};
8+
use anyhow::Context;
9+
use urlencoding::encode;
10+
11+
/// Client for the AWS Bedrock Runtime Converse REST API.
pub struct Client {
    // Bearer token, read from the `BEDROCK_API_KEY` environment variable.
    api_key: String,
    // AWS region used to build the endpoint host; from `BEDROCK_REGION`,
    // defaulting to "us-east-1".
    region: String,
    // Shared reqwest HTTP client, reused across requests.
    client: reqwest::Client,
}
16+
17+
impl Client {
    /// Creates a Bedrock client configured from the environment.
    ///
    /// `address` must be `None`: Bedrock endpoints are derived from the region,
    /// so custom API addresses are rejected.
    ///
    /// # Errors
    /// Fails if `address` is provided, or if the `BEDROCK_API_KEY`
    /// environment variable is not set.
    pub async fn new(address: Option<String>) -> Result<Self> {
        if address.is_some() {
            api_bail!("Bedrock doesn't support custom API address");
        }

        let api_key = match std::env::var("BEDROCK_API_KEY") {
            Ok(val) => val,
            Err(_) => api_bail!("BEDROCK_API_KEY environment variable must be set"),
        };

        // Default to us-east-1 if no region specified
        let region = std::env::var("BEDROCK_REGION").unwrap_or_else(|_| "us-east-1".to_string());

        Ok(Self {
            api_key,
            region,
            client: reqwest::Client::new(),
        })
    }
}
38+
39+
#[async_trait]
40+
impl LlmGenerationClient for Client {
41+
async fn generate<'req>(
42+
&self,
43+
request: LlmGenerateRequest<'req>,
44+
) -> Result<LlmGenerateResponse> {
45+
let mut user_content_parts: Vec<serde_json::Value> = Vec::new();
46+
47+
// Add image part if present
48+
if let Some(image_bytes) = &request.image {
49+
let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
50+
let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
51+
user_content_parts.push(serde_json::json!({
52+
"image": {
53+
"format": mime_type.split('/').nth(1).unwrap_or("png"),
54+
"source": {
55+
"bytes": base64_image,
56+
}
57+
}
58+
}));
59+
}
60+
61+
// Add text part
62+
user_content_parts.push(serde_json::json!({
63+
"text": request.user_prompt
64+
}));
65+
66+
let messages = vec![serde_json::json!({
67+
"role": "user",
68+
"content": user_content_parts
69+
})];
70+
71+
let mut payload = serde_json::json!({
72+
"messages": messages,
73+
"inferenceConfig": {
74+
"maxTokens": 4096
75+
}
76+
});
77+
78+
// Add system prompt if present
79+
if let Some(system) = request.system_prompt {
80+
payload["system"] = serde_json::json!([{
81+
"text": system
82+
}]);
83+
}
84+
85+
// Handle structured output using tool schema
86+
if let Some(OutputFormat::JsonSchema { schema, name }) = request.output_format.as_ref() {
87+
let schema_json = serde_json::to_value(schema)?;
88+
payload["toolConfig"] = serde_json::json!({
89+
"tools": [{
90+
"toolSpec": {
91+
"name": name,
92+
"description": format!("Extract structured data according to the schema"),
93+
"inputSchema": {
94+
"json": schema_json
95+
}
96+
}
97+
}]
98+
});
99+
}
100+
101+
// Construct the Bedrock Runtime API URL
102+
let url = format!(
103+
"https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
104+
self.region, request.model
105+
);
106+
107+
let encoded_api_key = encode(&self.api_key);
108+
109+
let resp = retryable::run(
110+
|| async {
111+
self.client
112+
.post(&url)
113+
.header(
114+
"Authorization",
115+
format!("Bearer {}", encoded_api_key.as_ref()),
116+
)
117+
.header("Content-Type", "application/json")
118+
.json(&payload)
119+
.send()
120+
.await?
121+
.error_for_status()
122+
},
123+
&retryable::HEAVY_LOADED_OPTIONS,
124+
)
125+
.await
126+
.context("Bedrock API error")?;
127+
128+
let resp_json: serde_json::Value = resp.json().await.context("Invalid JSON")?;
129+
130+
// Check for errors in the response
131+
if let Some(error) = resp_json.get("error") {
132+
bail!("Bedrock API error: {:?}", error);
133+
}
134+
135+
// Debug print full response (uncomment for debugging)
136+
// println!("Bedrock API full response: {resp_json:?}");
137+
138+
// Extract the response content
139+
let output = &resp_json["output"];
140+
let message = &output["message"];
141+
let content = &message["content"];
142+
143+
let text = if let Some(content_array) = content.as_array() {
144+
// Look for tool use first (structured output)
145+
let mut extracted_json: Option<serde_json::Value> = None;
146+
for item in content_array {
147+
if let Some(tool_use) = item.get("toolUse") {
148+
if let Some(input) = tool_use.get("input") {
149+
extracted_json = Some(input.clone());
150+
break;
151+
}
152+
}
153+
}
154+
155+
if let Some(json) = extracted_json {
156+
// Return the structured output as JSON
157+
serde_json::to_string(&json)?
158+
} else {
159+
// Fall back to text content
160+
let mut text_parts = Vec::new();
161+
for item in content_array {
162+
if let Some(text) = item.get("text") {
163+
if let Some(text_str) = text.as_str() {
164+
text_parts.push(text_str);
165+
}
166+
}
167+
}
168+
text_parts.join("")
169+
}
170+
} else {
171+
return Err(anyhow::anyhow!("No content found in Bedrock response"));
172+
};
173+
174+
Ok(LlmGenerateResponse { text })
175+
}
176+
177+
fn json_schema_options(&self) -> ToJsonSchemaOptions {
178+
ToJsonSchemaOptions {
179+
fields_always_required: false,
180+
supports_format: false,
181+
extract_descriptions: false,
182+
top_level_must_be_object: true,
183+
}
184+
}
185+
}

src/llm/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub enum LlmApiType {
1818
Voyage,
1919
Vllm,
2020
VertexAi,
21+
Bedrock,
2122
}
2223

2324
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -106,6 +107,7 @@ pub trait LlmEmbeddingClient: Send + Sync {
106107
}
107108

108109
mod anthropic;
110+
mod bedrock;
109111
mod gemini;
110112
mod litellm;
111113
mod ollama;
@@ -134,6 +136,9 @@ pub async fn new_llm_generation_client(
134136
LlmApiType::Anthropic => {
135137
Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
136138
}
139+
LlmApiType::Bedrock => {
140+
Box::new(bedrock::Client::new(address).await?) as Box<dyn LlmGenerationClient>
141+
}
137142
LlmApiType::LiteLlm => {
138143
Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
139144
}
@@ -169,7 +174,11 @@ pub async fn new_llm_embedding_client(
169174
}
170175
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
171176
as Box<dyn LlmEmbeddingClient>,
172-
LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => {
177+
LlmApiType::OpenRouter
178+
| LlmApiType::LiteLlm
179+
| LlmApiType::Vllm
180+
| LlmApiType::Anthropic
181+
| LlmApiType::Bedrock => {
173182
api_bail!("Embedding is not supported for API type {:?}", api_type)
174183
}
175184
};

0 commit comments

Comments
 (0)