26 changes: 26 additions & 0 deletions docs/docs/ai/llm.mdx
@@ -28,6 +28,7 @@ We support the following types of LLM APIs:
| [LiteLLM](#litellm) | `LlmApiType.LITE_LLM` | ✅ | ❌ |
| [OpenRouter](#openrouter) | `LlmApiType.OPEN_ROUTER` | ✅ | ❌ |
| [vLLM](#vllm) | `LlmApiType.VLLM` | ✅ | ❌ |
| [Bedrock](#bedrock) | `LlmApiType.BEDROCK` | ✅ | ❌ |

## LLM Tasks

@@ -440,3 +441,28 @@ cocoindex.LlmSpec(

</TabItem>
</Tabs>

### Bedrock

To use the Bedrock API, you need an Amazon Bedrock API key. Set the following environment variables:

- `BEDROCK_API_KEY`: your Bedrock API key, sent as a bearer token.
- `BEDROCK_REGION` (optional): the AWS region to send requests to. Defaults to `us-east-1`.
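
For local development, these can also be set from Python before the flow runs. A minimal sketch, with hypothetical placeholder values:

```python
import os

# Hypothetical placeholder values; substitute your real Bedrock API key.
os.environ["BEDROCK_API_KEY"] = "<your-bedrock-api-key>"
os.environ["BEDROCK_REGION"] = "us-west-2"  # optional; the client defaults to us-east-1
```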

A spec for Bedrock looks like this:

<Tabs>
<TabItem value="python" label="Python" default>

```python
cocoindex.LlmSpec(
    api_type=cocoindex.LlmApiType.BEDROCK,
    model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
)
```

</TabItem>
</Tabs>

You can find the full list of models supported by Bedrock [here](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).
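
For context, here is a minimal sketch of plugging this spec into an extraction transform. It mirrors the `examples/manuals_llm_extraction` change below; `ExtractByLlm`'s argument names follow that example, and the `ModuleInfo` dataclass here is a simplified stand-in:

```python
import dataclasses

import cocoindex


@dataclasses.dataclass
class ModuleInfo:
    """Simplified stand-in; the real example defines its own schema."""

    name: str
    description: str


# Assumes BEDROCK_API_KEY (and optionally BEDROCK_REGION) are set in the environment.
extract_module_info = cocoindex.functions.ExtractByLlm(
    llm_spec=cocoindex.LlmSpec(
        api_type=cocoindex.LlmApiType.BEDROCK,
        model="us.anthropic.claude-3-5-haiku-20241022-v1:0",
    ),
    output_type=ModuleInfo,
    instruction="Please extract Python module information from the manual.",
)
```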
3 changes: 3 additions & 0 deletions examples/manuals_llm_extraction/main.py
@@ -118,6 +118,9 @@ def manual_extraction_flow(
                # Replace with the spec below to use an Anthropic API model
                # llm_spec=cocoindex.LlmSpec(
                #     api_type=cocoindex.LlmApiType.ANTHROPIC, model="claude-3-5-sonnet-latest"),
                # Replace with the spec below to use a Bedrock API model
                # llm_spec=cocoindex.LlmSpec(
                #     api_type=cocoindex.LlmApiType.BEDROCK, model="us.anthropic.claude-3-5-haiku-20241022-v1:0"),
                output_type=ModuleInfo,
                instruction="Please extract Python module information from the manual.",
            )
1 change: 1 addition & 0 deletions python/cocoindex/llm.py
@@ -14,6 +14,7 @@ class LlmApiType(Enum):
    OPEN_ROUTER = "OpenRouter"
    VOYAGE = "Voyage"
    VLLM = "Vllm"
    BEDROCK = "Bedrock"


@dataclass
185 changes: 185 additions & 0 deletions src/llm/bedrock.rs
@@ -0,0 +1,185 @@
use crate::prelude::*;
use base64::prelude::*;

use crate::llm::{
    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
    ToJsonSchemaOptions, detect_image_mime_type,
};
use anyhow::Context;
use urlencoding::encode;

pub struct Client {
    api_key: String,
    region: String,
    client: reqwest::Client,
}

impl Client {
    pub async fn new(address: Option<String>) -> Result<Self> {
        if address.is_some() {
            api_bail!("Bedrock doesn't support custom API address");
        }

        let api_key = match std::env::var("BEDROCK_API_KEY") {
            Ok(val) => val,
            Err(_) => api_bail!("BEDROCK_API_KEY environment variable must be set"),
        };

        // Default to us-east-1 if no region is specified.
        let region = std::env::var("BEDROCK_REGION").unwrap_or_else(|_| "us-east-1".to_string());

        Ok(Self {
            api_key,
            region,
            client: reqwest::Client::new(),
        })
    }
}

#[async_trait]
impl LlmGenerationClient for Client {
    async fn generate<'req>(
        &self,
        request: LlmGenerateRequest<'req>,
    ) -> Result<LlmGenerateResponse> {
        let mut user_content_parts: Vec<serde_json::Value> = Vec::new();

        // Add image part if present.
        if let Some(image_bytes) = &request.image {
            let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
            let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
            user_content_parts.push(serde_json::json!({
                "image": {
                    "format": mime_type.split('/').nth(1).unwrap_or("png"),
                    "source": {
                        "bytes": base64_image,
                    }
                }
            }));
        }

        // Add text part.
        user_content_parts.push(serde_json::json!({
            "text": request.user_prompt
        }));

        let messages = vec![serde_json::json!({
            "role": "user",
            "content": user_content_parts
        })];

        let mut payload = serde_json::json!({
            "messages": messages,
            "inferenceConfig": {
                "maxTokens": 4096
            }
        });

        // Add system prompt if present.
        if let Some(system) = request.system_prompt {
            payload["system"] = serde_json::json!([{
                "text": system
            }]);
        }

        // Handle structured output using a tool schema.
        if let Some(OutputFormat::JsonSchema { schema, name }) = request.output_format.as_ref() {
            let schema_json = serde_json::to_value(schema)?;
            payload["toolConfig"] = serde_json::json!({
                "tools": [{
                    "toolSpec": {
                        "name": name,
                        "description": "Extract structured data according to the schema",
                        "inputSchema": {
                            "json": schema_json
                        }
                    }
                }]
            });
        }

        // Construct the Bedrock Runtime API URL.
        let url = format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
            self.region, request.model
        );

        let encoded_api_key = encode(&self.api_key);

        let resp = retryable::run(
            || async {
                self.client
                    .post(&url)
                    .header(
                        "Authorization",
                        format!("Bearer {}", encoded_api_key.as_ref()),
                    )
                    .header("Content-Type", "application/json")
                    .json(&payload)
                    .send()
                    .await?
                    .error_for_status()
            },
            &retryable::HEAVY_LOADED_OPTIONS,
        )
        .await
        .context("Bedrock API error")?;

        let resp_json: serde_json::Value = resp.json().await.context("Invalid JSON")?;

        // Check for errors in the response.
        if let Some(error) = resp_json.get("error") {
            bail!("Bedrock API error: {:?}", error);
        }

        // Debug print full response (uncomment for debugging)
        // println!("Bedrock API full response: {resp_json:?}");

        // Extract the response content.
        let output = &resp_json["output"];
        let message = &output["message"];
        let content = &message["content"];

        let text = if let Some(content_array) = content.as_array() {
            // Look for tool use first (structured output).
            let mut extracted_json: Option<serde_json::Value> = None;
            for item in content_array {
                if let Some(tool_use) = item.get("toolUse") {
                    if let Some(input) = tool_use.get("input") {
                        extracted_json = Some(input.clone());
                        break;
                    }
                }
            }

            if let Some(json) = extracted_json {
                // Return the structured output as JSON.
                serde_json::to_string(&json)?
            } else {
                // Fall back to text content.
                let mut text_parts = Vec::new();
                for item in content_array {
                    if let Some(text) = item.get("text") {
                        if let Some(text_str) = text.as_str() {
                            text_parts.push(text_str);
                        }
                    }
                }
                text_parts.join("")
            }
        } else {
            bail!("No content found in Bedrock response");
        };

        Ok(LlmGenerateResponse { text })
    }

    fn json_schema_options(&self) -> ToJsonSchemaOptions {
        ToJsonSchemaOptions {
            fields_always_required: false,
            supports_format: false,
            extract_descriptions: false,
            top_level_must_be_object: true,
        }
    }
}
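
To make the request and response handling above concrete, here is a hedged Python sketch of the Converse-style payload this client builds and the response fields it reads back. The shapes mirror the `serde_json::json!` literals and the `toolUse`/`text` extraction logic above; the concrete values are illustrative only:

```python
import json

# Illustrative request body, mirroring what the Rust client constructs.
payload = {
    "messages": [{"role": "user", "content": [{"text": "Extract the modules."}]}],
    "inferenceConfig": {"maxTokens": 4096},
    "system": [{"text": "You are an extraction assistant."}],
    # toolConfig is attached only when structured (JSON-schema) output is requested.
    "toolConfig": {
        "tools": [{
            "toolSpec": {
                "name": "extract",
                "description": "Extract structured data according to the schema",
                "inputSchema": {"json": {"type": "object"}},
            }
        }]
    },
}

# Illustrative response: the client prefers a toolUse block (structured output)
# and falls back to concatenating text blocks.
response = {"output": {"message": {"content": [{"toolUse": {"input": {"name": "os"}}}]}}}
content = response["output"]["message"]["content"]
tool_inputs = [c["toolUse"]["input"] for c in content if "toolUse" in c]
text = json.dumps(tool_inputs[0]) if tool_inputs else "".join(
    c.get("text", "") for c in content
)
print(text)  # -> {"name": "os"}
```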
11 changes: 10 additions & 1 deletion src/llm/mod.rs
@@ -18,6 +18,7 @@ pub enum LlmApiType {
    Voyage,
    Vllm,
    VertexAi,
    Bedrock,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -106,6 +107,7 @@ pub trait LlmEmbeddingClient: Send + Sync {
}

mod anthropic;
mod bedrock;
mod gemini;
mod litellm;
mod ollama;
@@ -134,6 +136,9 @@ pub async fn new_llm_generation_client(
        LlmApiType::Anthropic => {
            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
        }
        LlmApiType::Bedrock => {
            Box::new(bedrock::Client::new(address).await?) as Box<dyn LlmGenerationClient>
        }
        LlmApiType::LiteLlm => {
            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
        }
@@ -169,7 +174,11 @@ pub async fn new_llm_embedding_client(
        }
        LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
            as Box<dyn LlmEmbeddingClient>,
        LlmApiType::OpenRouter
        | LlmApiType::LiteLlm
        | LlmApiType::Vllm
        | LlmApiType::Anthropic
        | LlmApiType::Bedrock => {
            api_bail!("Embedding is not supported for API type {:?}", api_type)
        }
    };