diff --git a/docs/docs/ai/llm.mdx b/docs/docs/ai/llm.mdx index 42fdfbb59..a01f79eb1 100644 --- a/docs/docs/ai/llm.mdx +++ b/docs/docs/ai/llm.mdx @@ -28,6 +28,7 @@ We support the following types of LLM APIs: | [LiteLLM](#litellm) | `LlmApiType.LITE_LLM` | ✅ | ❌ | | [OpenRouter](#openrouter) | `LlmApiType.OPEN_ROUTER` | ✅ | ❌ | | [vLLM](#vllm) | `LlmApiType.VLLM` | ✅ | ❌ | +| [Bedrock](#bedrock) | `LlmApiType.BEDROCK` | ✅ | ❌ | ## LLM Tasks @@ -440,3 +441,28 @@ cocoindex.LlmSpec( + +### Bedrock + +To use the Bedrock API, you need an Amazon Bedrock API key. You can configure it by setting the following environment variables: + +- `BEDROCK_API_KEY` (required) +- `BEDROCK_REGION` (optional) +- The region defaults to `us-east-1` when `BEDROCK_REGION` is not set. + +A spec for Bedrock looks like this: + + + + +```python +cocoindex.LlmSpec( + api_type=cocoindex.LlmApiType.BEDROCK, + model="us.anthropic.claude-3-5-haiku-20241022-v1:0", +) +``` + + + + +You can find the full list of models supported by Bedrock [here](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html). 
diff --git a/examples/manuals_llm_extraction/main.py b/examples/manuals_llm_extraction/main.py index e35a7700d..769272269 100644 --- a/examples/manuals_llm_extraction/main.py +++ b/examples/manuals_llm_extraction/main.py @@ -118,6 +118,9 @@ def manual_extraction_flow( # Replace by this spec below, to use Anthropic API model # llm_spec=cocoindex.LlmSpec( # api_type=cocoindex.LlmApiType.ANTHROPIC, model="claude-3-5-sonnet-latest"), + # Replace by this spec below, to use Bedrock API model + # llm_spec=cocoindex.LlmSpec( + # api_type=cocoindex.LlmApiType.BEDROCK, model="us.anthropic.claude-3-5-haiku-20241022-v1:0"), output_type=ModuleInfo, instruction="Please extract Python module information from the manual.", ) diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index 4dc98c429..3f12c90a6 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -14,6 +14,7 @@ class LlmApiType(Enum): OPEN_ROUTER = "OpenRouter" VOYAGE = "Voyage" VLLM = "Vllm" + BEDROCK = "Bedrock" @dataclass diff --git a/src/llm/bedrock.rs b/src/llm/bedrock.rs new file mode 100644 index 000000000..59c9cb4eb --- /dev/null +++ b/src/llm/bedrock.rs @@ -0,0 +1,185 @@ +use crate::prelude::*; +use base64::prelude::*; + +use crate::llm::{ + LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat, + ToJsonSchemaOptions, detect_image_mime_type, +}; +use anyhow::Context; +use urlencoding::encode; + +pub struct Client { + api_key: String, + region: String, + client: reqwest::Client, +} + +impl Client { + pub async fn new(address: Option) -> Result { + if address.is_some() { + api_bail!("Bedrock doesn't support custom API address"); + } + + let api_key = match std::env::var("BEDROCK_API_KEY") { + Ok(val) => val, + Err(_) => api_bail!("BEDROCK_API_KEY environment variable must be set"), + }; + + // Default to us-east-1 if no region specified + let region = std::env::var("BEDROCK_REGION").unwrap_or_else(|_| "us-east-1".to_string()); + + Ok(Self { + api_key, + region, + 
client: reqwest::Client::new(), + }) + } +} + +#[async_trait] +impl LlmGenerationClient for Client { + async fn generate<'req>( + &self, + request: LlmGenerateRequest<'req>, + ) -> Result { + let mut user_content_parts: Vec = Vec::new(); + + // Add image part if present + if let Some(image_bytes) = &request.image { + let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref()); + let mime_type = detect_image_mime_type(image_bytes.as_ref())?; + user_content_parts.push(serde_json::json!({ + "image": { + "format": mime_type.split('/').nth(1).unwrap_or("png"), + "source": { + "bytes": base64_image, + } + } + })); + } + + // Add text part + user_content_parts.push(serde_json::json!({ + "text": request.user_prompt + })); + + let messages = vec![serde_json::json!({ + "role": "user", + "content": user_content_parts + })]; + + let mut payload = serde_json::json!({ + "messages": messages, + "inferenceConfig": { + "maxTokens": 4096 + } + }); + + // Add system prompt if present + if let Some(system) = request.system_prompt { + payload["system"] = serde_json::json!([{ + "text": system + }]); + } + + // Handle structured output using tool schema + if let Some(OutputFormat::JsonSchema { schema, name }) = request.output_format.as_ref() { + let schema_json = serde_json::to_value(schema)?; + payload["toolConfig"] = serde_json::json!({ + "tools": [{ + "toolSpec": { + "name": name, + "description": format!("Extract structured data according to the schema"), + "inputSchema": { + "json": schema_json + } + } + }] + }); + } + + // Construct the Bedrock Runtime API URL + let url = format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", + self.region, request.model + ); + + let encoded_api_key = encode(&self.api_key); + + let resp = retryable::run( + || async { + self.client + .post(&url) + .header( + "Authorization", + format!("Bearer {}", encoded_api_key.as_ref()), + ) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await? 
+ .error_for_status() + }, + &retryable::HEAVY_LOADED_OPTIONS, + ) + .await + .context("Bedrock API error")?; + + let resp_json: serde_json::Value = resp.json().await.context("Invalid JSON")?; + + // Check for errors in the response + if let Some(error) = resp_json.get("error") { + bail!("Bedrock API error: {:?}", error); + } + + // Debug print full response (uncomment for debugging) + // println!("Bedrock API full response: {resp_json:?}"); + + // Extract the response content + let output = &resp_json["output"]; + let message = &output["message"]; + let content = &message["content"]; + + let text = if let Some(content_array) = content.as_array() { + // Look for tool use first (structured output) + let mut extracted_json: Option = None; + for item in content_array { + if let Some(tool_use) = item.get("toolUse") { + if let Some(input) = tool_use.get("input") { + extracted_json = Some(input.clone()); + break; + } + } + } + + if let Some(json) = extracted_json { + // Return the structured output as JSON + serde_json::to_string(&json)? 
+ } else { + // Fall back to text content + let mut text_parts = Vec::new(); + for item in content_array { + if let Some(text) = item.get("text") { + if let Some(text_str) = text.as_str() { + text_parts.push(text_str); + } + } + } + text_parts.join("") + } + } else { + return Err(anyhow::anyhow!("No content found in Bedrock response")); + }; + + Ok(LlmGenerateResponse { text }) + } + + fn json_schema_options(&self) -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: false, + supports_format: false, + extract_descriptions: false, + top_level_must_be_object: true, + } + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 2145f2d12..12eda662a 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -18,6 +18,7 @@ pub enum LlmApiType { Voyage, Vllm, VertexAi, + Bedrock, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -106,6 +107,7 @@ pub trait LlmEmbeddingClient: Send + Sync { } mod anthropic; +mod bedrock; mod gemini; mod litellm; mod ollama; @@ -134,6 +136,9 @@ pub async fn new_llm_generation_client( LlmApiType::Anthropic => { Box::new(anthropic::Client::new(address).await?) as Box } + LlmApiType::Bedrock => { + Box::new(bedrock::Client::new(address).await?) as Box + } LlmApiType::LiteLlm => { Box::new(litellm::Client::new_litellm(address).await?) as Box } @@ -169,7 +174,11 @@ pub async fn new_llm_embedding_client( } LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) as Box, - LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => { + LlmApiType::OpenRouter + | LlmApiType::LiteLlm + | LlmApiType::Vllm + | LlmApiType::Anthropic + | LlmApiType::Bedrock => { api_bail!("Embedding is not supported for API type {:?}", api_type) } };