Skip to content

Commit 8e96ee6

Browse files
kausmeows, dirkbrnd
authored and committed
add aws bedrock embedder (agno-agi#3075)
## Summary AWS Bedrock embedder- Cohere Embed-multilingual v3 ![image](https://github.com/user-attachments/assets/751c5ded-1519-40a2-9ebd-5291be855bbf) (If applicable, issue number: #____) ## Type of change - [ ] Bug fix - [x] New feature - [ ] Breaking change - [ ] Improvement - [ ] Model update - [ ] Other: --- ## Checklist - [ ] Code complies with style guidelines - [ ] Ran format/validation scripts (`./scripts/format.sh` and `./scripts/validate.sh`) - [ ] Self-review completed - [ ] Documentation updated (comments, docstrings) - [ ] Examples and guides: Relevant cookbook examples have been included or updated (if applicable) - [ ] Tested in clean environment - [ ] Tests added/updated (if applicable) --- ## Additional Notes Add any important context (deployment instructions, screenshots, security considerations, etc.) --------- Co-authored-by: Dirk Brand <[email protected]>
1 parent d69b270 commit 8e96ee6

File tree

11 files changed

+258
-12
lines changed

11 files changed

+258
-12
lines changed

cookbook/agent_concepts/context/__init__.py

Whitespace-only changes.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from agno.document.reader.pdf_reader import PDFUrlReader
2+
from agno.embedder.aws_bedrock import AwsBedrockEmbedder
3+
from agno.knowledge.pdf_url import PDFUrlKnowledgeBase
4+
from agno.vectordb.pgvector import PgVector
5+
6+
# Embed one sentence directly to sanity-check credentials and model access.
embeddings = AwsBedrockEmbedder().get_embedding("The quick brown fox jumps over the lazy dog.")

# Show the first few components and the overall vector length.
print(f"Embeddings: {embeddings[:5]}")
print(f"Dimensions: {len(embeddings)}")

# Example usage: index a PDF into PgVector using the Bedrock embedder.
# chunk_size is set explicitly because cohere has a fixed size of 2048.
pdf_reader = PDFUrlReader(chunk_size=2048)
recipes_db = PgVector(
    table_name="recipes",
    db_url="postgresql+psycopg://ai:ai@localhost:5532/ai",
    embedder=AwsBedrockEmbedder(),
)
knowledge_base = PDFUrlKnowledgeBase(
    urls=["https://agno-public.s3.amazonaws.com/recipes/ThaiRecipes.pdf"],
    reader=pdf_reader,
    vector_db=recipes_db,
)
knowledge_base.load(recreate=False)

cookbook/agent_concepts/state/__init__.py

Whitespace-only changes.

libs/agno/agno/document/chunking/fixed.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def chunk(self, document: Document) -> List[Document]:
2222
chunked_documents: List[Document] = []
2323
chunk_number = 1
2424
chunk_meta_data = document.meta_data
25-
2625
start = 0
2726
while start + self.overlap < content_length:
2827
end = min(start + self.chunk_size, content_length)
@@ -55,5 +54,4 @@ def chunk(self, document: Document) -> List[Document]:
5554
)
5655
chunk_number += 1
5756
start = end - self.overlap
58-
5957
return chunked_documents

libs/agno/agno/document/reader/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from dataclasses import dataclass, field
3-
from typing import Any, List
3+
from typing import Any, List, Optional
44

55
from agno.document.base import Document
66
from agno.document.chunking.fixed import FixedSizeChunking
@@ -12,9 +12,13 @@ class Reader:
1212
"""Base class for reading documents"""
1313

1414
chunk: bool = True
15-
chunk_size: int = 3000
15+
chunk_size: int = 5000
1616
separators: List[str] = field(default_factory=lambda: ["\n", "\n\n", "\r", "\r\n", "\n\r", "\t", " ", " "])
17-
chunking_strategy: ChunkingStrategy = field(default_factory=FixedSizeChunking)
17+
chunking_strategy: Optional[ChunkingStrategy] = None
18+
19+
def __init__(self, chunk_size: int = 5000, chunking_strategy: Optional[ChunkingStrategy] = None) -> None:
20+
self.chunk_size = chunk_size
21+
self.chunking_strategy = chunking_strategy or FixedSizeChunking(chunk_size=self.chunk_size)
1822

1923
def read(self, obj: Any) -> List[Document]:
2024
raise NotImplementedError
@@ -23,7 +27,7 @@ async def async_read(self, obj: Any) -> List[Document]:
2327
raise NotImplementedError
2428

2529
def chunk_document(self, document: Document) -> List[Document]:
26-
return self.chunking_strategy.chunk(document)
30+
return self.chunking_strategy.chunk(document) # type: ignore
2731

2832
async def chunk_documents_async(self, documents: List[Document]) -> List[Document]:
2933
"""

libs/agno/agno/document/reader/firecrawl_reader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,18 @@
1111
except ImportError:
1212
raise ImportError("The `firecrawl` package is not installed. Please install it via `pip install firecrawl-py`.")
1313

14-
1514
@dataclass
1615
class FirecrawlReader(Reader):
1716
api_key: Optional[str] = None
1817
params: Optional[Dict] = None
1918
mode: Literal["scrape", "crawl"] = "scrape"
19+
20+
def __init__(self, api_key: Optional[str] = None, params: Optional[Dict] = None, mode: Literal["scrape", "crawl"] = "scrape", *args, **kwargs) -> None:
21+
super().__init__(*args, **kwargs)
22+
self.api_key = api_key
23+
self.params = params
24+
self.mode = mode
25+
2026

2127
def scrape(self, url: str) -> List[Document]:
2228
"""
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import json
2+
from dataclasses import dataclass
3+
from os import getenv
4+
from typing import Any, Dict, List, Optional, Tuple
5+
6+
from agno.embedder.base import Embedder
7+
from agno.exceptions import AgnoError, ModelProviderError
8+
from agno.utils.log import log_error, logger
9+
10+
try:
11+
from boto3 import client as AwsClient
12+
from boto3.session import Session
13+
from botocore.exceptions import ClientError
14+
except ImportError:
15+
log_error("`boto3` not installed. Please install it via `pip install boto3`.")
16+
raise
17+
18+
19+
@dataclass
class AwsBedrockEmbedder(Embedder):
    """
    AWS Bedrock embedder.

    To use this embedder, you need to either:
    1. Provide credentials, via the `aws_access_key_id` / `aws_secret_access_key` /
       `aws_region` arguments or the following environment variables:
       - AWS_ACCESS_KEY_ID
       - AWS_SECRET_ACCESS_KEY
       - AWS_REGION
    2. Or provide a boto3 Session object

    Args:
        id (str): The model ID to use. Default is 'cohere.embed-multilingual-v3'.
        dimensions (int): The dimensions of the embeddings. Default is 1024.
        input_type (str): Prepends special tokens to differentiate types. Options:
            'search_document', 'search_query', 'classification', 'clustering'. Default is 'search_query'.
        truncate (Optional[str]): How to handle inputs longer than the maximum token length.
            Options: 'NONE', 'START', 'END'. Default is None, in which case the field is
            omitted from the request.
        embedding_types (Optional[List[str]]): Types of embeddings to return. Options:
            'float', 'int8', 'uint8', 'binary', 'ubinary'. Default is None, in which case the
            field is omitted from the request.
        aws_region (Optional[str]): The AWS region to use.
        aws_access_key_id (Optional[str]): The AWS access key ID to use.
        aws_secret_access_key (Optional[str]): The AWS secret access key to use.
        session (Optional[Session]): A boto3 Session object to use for authentication.
        request_params (Optional[Dict[str, Any]]): Additional parameters merged into the request body.
        client_params (Optional[Dict[str, Any]]): Additional parameters to pass to the boto3 client.
        client (Optional[AwsClient]): A pre-built client; when set it is used as-is.
    """

    id: str = "cohere.embed-multilingual-v3"
    dimensions: int = 1024  # Cohere models have 1024 dimensions by default
    input_type: str = "search_query"
    truncate: Optional[str] = None  # 'NONE', 'START', or 'END'
    # 'float', 'int8', 'uint8', etc.
    embedding_types: Optional[List[str]] = None

    aws_region: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    session: Optional[Session] = None

    request_params: Optional[Dict[str, Any]] = None
    client_params: Optional[Dict[str, Any]] = None
    client: Optional[AwsClient] = None

    def get_client(self) -> AwsClient:
        """
        Returns an AWS Bedrock client, creating and caching it on first use.

        Resolution order: an already-set `client`, then `session`, then explicit
        credentials (arguments first, environment variables as fallback).

        Returns:
            AwsClient: An instance of the AWS Bedrock client.

        Raises:
            AgnoError: If no session is given and the access key ID or secret
                access key cannot be resolved.
        """
        if self.client is not None:
            return self.client

        extra_client_params = self.client_params or {}

        if self.session:
            # Honor client_params here too, so they behave the same on both paths.
            self.client = self.session.client("bedrock-runtime", **extra_client_params)
            return self.client

        self.aws_access_key_id = self.aws_access_key_id or getenv("AWS_ACCESS_KEY_ID")
        self.aws_secret_access_key = self.aws_secret_access_key or getenv("AWS_SECRET_ACCESS_KEY")
        self.aws_region = self.aws_region or getenv("AWS_REGION")

        if not self.aws_access_key_id or not self.aws_secret_access_key:
            raise AgnoError(
                message="AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or provide a boto3 session.",
                status_code=400,
            )

        self.client = AwsClient(
            service_name="bedrock-runtime",
            region_name=self.aws_region,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            **extra_client_params,
        )
        return self.client

    def _format_request_body(self, text: str) -> str:
        """
        Format the request body for the embedder.

        Args:
            text (str): The text to embed.

        Returns:
            str: The formatted request body as a JSON string.
        """
        request_body: Dict[str, Any] = {
            "texts": [text],
            "input_type": self.input_type,
        }

        # Optional fields are only sent when explicitly configured.
        if self.truncate:
            request_body["truncate"] = self.truncate

        if self.embedding_types:
            request_body["embedding_types"] = self.embedding_types

        # Applied last, so request_params can override any of the keys above.
        if self.request_params:
            request_body.update(self.request_params)

        return json.dumps(request_body)

    @staticmethod
    def _extract_embedding(response: Any) -> List[float]:
        """
        Pull a single embedding vector out of a decoded Bedrock response.

        Handles both response shapes: a plain list of vectors, or a dict keyed by
        embedding type (returned when `embedding_types` was requested). In the dict
        case 'float' is preferred, otherwise the first available type is used.

        Args:
            response (Any): The decoded JSON response body.

        Returns:
            List[float]: The embedding vector, or an empty list when none can be
                extracted (a warning is logged in that case).
        """
        try:
            if "embeddings" in response:
                payload = response["embeddings"]
                if isinstance(payload, list):
                    # Default 'float' embeddings response format
                    return payload[0]
                if isinstance(payload, dict):
                    if "float" in payload:
                        return payload["float"][0]
                    # Fallback to the first available embedding type
                    for vectors in payload.values():
                        return vectors[0]
        except Exception as e:
            # Best effort: a malformed payload yields an empty vector, not a crash.
            logger.warning(f"Error extracting embeddings: {e}")
            return []

        logger.warning("No embeddings found in response")
        return []

    def response(self, text: str) -> Dict[str, Any]:
        """
        Get embeddings from AWS Bedrock for the given text.

        Args:
            text (str): The text to embed.

        Returns:
            Dict[str, Any]: The decoded JSON response from the API.

        Raises:
            ModelProviderError: If building the client, invoking the model or
                decoding the response fails.
        """
        try:
            body = self._format_request_body(text)
            response = self.get_client().invoke_model(
                modelId=self.id,
                body=body,
                contentType="application/json",
                accept="application/json",
            )
            response_body = json.loads(response["body"].read().decode("utf-8"))
            return response_body
        except ClientError as e:
            log_error(f"Unexpected error calling Bedrock API: {str(e)}")
            raise ModelProviderError(message=str(e.response), model_name="AwsBedrockEmbedder", model_id=self.id) from e
        except Exception as e:
            log_error(f"Unexpected error calling Bedrock API: {str(e)}")
            raise ModelProviderError(message=str(e), model_name="AwsBedrockEmbedder", model_id=self.id) from e

    def get_embedding(self, text: str) -> List[float]:
        """
        Get embeddings for the given text.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding vector, or an empty list if the response
                contains no usable embedding.
        """
        response = self.response(text=text)
        return self._extract_embedding(response)

    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict[str, Any]]]:
        """
        Get embeddings and usage information for the given text.

        Args:
            text (str): The text to embed.

        Returns:
            Tuple[List[float], Optional[Dict[str, Any]]]: The embedding vector (empty
                if none could be extracted) and the response's "usage" entry, or
                None when the response has no such entry.
        """
        response = self.response(text=text)

        # Same extraction and error handling as get_embedding, via the shared helper.
        embedding = self._extract_embedding(response)

        # Extract usage metrics if available
        usage = response["usage"] if "usage" in response else None

        return embedding, usage

libs/agno/agno/knowledge/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AgentKnowledge(BaseModel):
2929

3030
@model_validator(mode="after")
3131
def update_reader(self) -> "AgentKnowledge":
32-
if self.reader is not None:
32+
if self.reader is not None and self.reader.chunking_strategy is None:
3333
self.reader.chunking_strategy = self.chunking_strategy
3434
return self
3535

libs/agno/tests/integration/teams/test_team_metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def test_team_metrics_basic():
5252
assert team.session_metrics.output_tokens is not None
5353
assert team.session_metrics.total_tokens is not None
5454

55+
5556
def test_team_metrics_streaming():
5657
"""Test team metrics with streaming."""
5758

libs/agno/tests/unit/reader/test_firecrawl_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_scrape_with_chunking(mock_scrape_response):
119119
# Create reader with chunking enabled
120120
reader = FirecrawlReader()
121121
reader.chunk = True
122-
reader.chunk_size = 10 # Small chunk size to ensure multiple chunks
122+
reader.chunking_strategy.chunk_size = 10 # Small chunk size to ensure multiple chunks
123123

124124
# Create a patch for chunk_document
125125
def mock_chunk_document(doc):
@@ -209,7 +209,7 @@ def test_crawl_with_chunking(mock_crawl_response):
209209
# Create reader with chunking enabled
210210
reader = FirecrawlReader(mode="crawl")
211211
reader.chunk = True
212-
reader.chunk_size = 10 # Small chunk size to ensure multiple chunks
212+
reader.chunking_strategy.chunk_size = 10 # Small chunk size to ensure multiple chunks
213213

214214
def mock_chunk_document(doc):
215215
# Simple mock that splits into 2 chunks

0 commit comments

Comments
 (0)