Skip to content

Commit 1dbe8aa

Browse files
fix: add batch_size support to prevent embedder token limit errors
- Add `batch_size` field to `BaseRagConfig` (default=100)
- Update ChromaDB/Qdrant clients and factories to use `batch_size`
- Extract and filter `batch_size` from the embedder config in `KnowledgeStorage`
- Fix large CSV files exceeding embedder token limits (#3574)
- Remove unneeded conditional for type

Co-authored-by: Vini Brasil <[email protected]>
1 parent 4ac65eb commit 1dbe8aa

File tree

12 files changed

+558
-56
lines changed

12 files changed

+558
-56
lines changed

src/crewai/knowledge/storage/knowledge_storage.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
99
from crewai.rag.config.utils import get_rag_client
1010
from crewai.rag.core.base_client import BaseClient
11-
from crewai.rag.embeddings.factory import get_embedding_function
11+
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
1212
from crewai.rag.factory import create_client
1313
from crewai.rag.types import BaseRecord, SearchResult
1414
from crewai.utilities.logger import Logger
@@ -27,6 +27,7 @@ def __init__(
2727
) -> None:
2828
self.collection_name = collection_name
2929
self._client: BaseClient | None = None
30+
self._embedder_config = embedder # Store embedder config
3031

3132
warnings.filterwarnings(
3233
"ignore",
@@ -35,12 +36,29 @@ def __init__(
3536
)
3637

3738
if embedder:
38-
embedding_function = get_embedding_function(embedder)
39-
config = ChromaDBConfig(
40-
embedding_function=cast(
41-
ChromaEmbeddingFunctionWrapper, embedding_function
39+
# Cast to EmbedderConfig for type checking
40+
embedder_typed = cast(EmbedderConfig, embedder)
41+
embedding_function = get_embedding_function(embedder_typed)
42+
batch_size = None
43+
if isinstance(embedder, dict) and "config" in embedder:
44+
nested_config = embedder["config"]
45+
if isinstance(nested_config, dict):
46+
batch_size = nested_config.get("batch_size")
47+
48+
# Create config with batch_size if provided
49+
if batch_size is not None:
50+
config = ChromaDBConfig(
51+
embedding_function=cast(
52+
ChromaEmbeddingFunctionWrapper, embedding_function
53+
),
54+
batch_size=batch_size,
55+
)
56+
else:
57+
config = ChromaDBConfig(
58+
embedding_function=cast(
59+
ChromaEmbeddingFunctionWrapper, embedding_function
60+
)
4261
)
43-
)
4462
self._client = create_client(config)
4563

4664
def _get_client(self) -> BaseClient:
@@ -105,9 +123,23 @@ def save(self, documents: list[str]) -> None:
105123

106124
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
107125

108-
client.add_documents(
109-
collection_name=collection_name, documents=rag_documents
110-
)
126+
batch_size = None
127+
if self._embedder_config and isinstance(self._embedder_config, dict):
128+
if "config" in self._embedder_config:
129+
nested_config = self._embedder_config["config"]
130+
if isinstance(nested_config, dict):
131+
batch_size = nested_config.get("batch_size")
132+
133+
if batch_size is not None:
134+
client.add_documents(
135+
collection_name=collection_name,
136+
documents=rag_documents,
137+
batch_size=batch_size,
138+
)
139+
else:
140+
client.add_documents(
141+
collection_name=collection_name, documents=rag_documents
142+
)
111143
except Exception as e:
112144
if "dimension mismatch" in str(e).lower():
113145
Logger(verbose=True).log(

src/crewai/memory/storage/rag_storage.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,28 @@ def __init__(
6666
f"Error: {e}"
6767
) from e
6868

69-
config = ChromaDBConfig(
70-
embedding_function=cast(
71-
ChromaEmbeddingFunctionWrapper, embedding_function
69+
batch_size = None
70+
if (
71+
isinstance(self.embedder_config, dict)
72+
and "config" in self.embedder_config
73+
):
74+
nested_config = self.embedder_config["config"]
75+
if isinstance(nested_config, dict):
76+
batch_size = nested_config.get("batch_size")
77+
78+
if batch_size is not None:
79+
config = ChromaDBConfig(
80+
embedding_function=cast(
81+
ChromaEmbeddingFunctionWrapper, embedding_function
82+
),
83+
batch_size=batch_size,
84+
)
85+
else:
86+
config = ChromaDBConfig(
87+
embedding_function=cast(
88+
ChromaEmbeddingFunctionWrapper, embedding_function
89+
)
7290
)
73-
)
7491
self._client = create_client(config)
7592

7693
def _get_client(self) -> BaseClient:
@@ -111,7 +128,26 @@ def save(self, value: Any, metadata: dict[str, Any]) -> None:
111128
if metadata:
112129
document["metadata"] = metadata
113130

114-
client.add_documents(collection_name=collection_name, documents=[document])
131+
batch_size = None
132+
if (
133+
self.embedder_config
134+
and isinstance(self.embedder_config, dict)
135+
and "config" in self.embedder_config
136+
):
137+
nested_config = self.embedder_config["config"]
138+
if isinstance(nested_config, dict):
139+
batch_size = nested_config.get("batch_size")
140+
141+
if batch_size is not None:
142+
client.add_documents(
143+
collection_name=collection_name,
144+
documents=[document],
145+
batch_size=batch_size,
146+
)
147+
else:
148+
client.add_documents(
149+
collection_name=collection_name, documents=[document]
150+
)
115151
except Exception as e:
116152
logging.error(
117153
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"

src/crewai/rag/chromadb/client.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ChromaDBCollectionSearchParams,
1818
)
1919
from crewai.rag.chromadb.utils import (
20+
_create_batch_slice,
2021
_extract_search_params,
2122
_is_async_client,
2223
_is_sync_client,
@@ -52,6 +53,7 @@ def __init__(
5253
embedding_function: ChromaEmbeddingFunction,
5354
default_limit: int = 5,
5455
default_score_threshold: float = 0.6,
56+
default_batch_size: int = 100,
5557
) -> None:
5658
"""Initialize ChromaDBClient with client and embedding function.
5759
@@ -60,11 +62,13 @@ def __init__(
6062
embedding_function: Embedding function for text to vector conversion.
6163
default_limit: Default number of results to return in searches.
6264
default_score_threshold: Default minimum score for search results.
65+
default_batch_size: Default batch size for adding documents.
6366
"""
6467
self.client = client
6568
self.embedding_function = embedding_function
6669
self.default_limit = default_limit
6770
self.default_score_threshold = default_score_threshold
71+
self.default_batch_size = default_batch_size
6872

6973
def create_collection(
7074
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -291,6 +295,7 @@ def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
291295
- content: The text content (required)
292296
- doc_id: Optional unique identifier (auto-generated if missing)
293297
- metadata: Optional metadata dictionary
298+
batch_size: Optional batch size for processing documents (default: 100)
294299
295300
Raises:
296301
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
@@ -305,6 +310,7 @@ def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
305310

306311
collection_name = kwargs["collection_name"]
307312
documents = kwargs["documents"]
313+
batch_size = kwargs.get("batch_size", self.default_batch_size)
308314

309315
if not documents:
310316
raise ValueError("Documents list cannot be empty")
@@ -315,13 +321,17 @@ def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
315321
)
316322

317323
prepared = _prepare_documents_for_chromadb(documents)
318-
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
319-
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
320-
collection.upsert(
321-
ids=prepared.ids,
322-
documents=prepared.texts,
323-
metadatas=metadatas,
324-
)
324+
325+
for i in range(0, len(prepared.ids), batch_size):
326+
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
327+
prepared=prepared, start_index=i, batch_size=batch_size
328+
)
329+
330+
collection.upsert(
331+
ids=batch_ids,
332+
documents=batch_texts,
333+
metadatas=batch_metadatas,
334+
)
325335

326336
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
327337
"""Add documents with their embeddings to a collection asynchronously.
@@ -335,6 +345,7 @@ async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> Non
335345
- content: The text content (required)
336346
- doc_id: Optional unique identifier (auto-generated if missing)
337347
- metadata: Optional metadata dictionary
348+
batch_size: Optional batch size for processing documents (default: 100)
338349
339350
Raises:
340351
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
@@ -349,6 +360,7 @@ async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> Non
349360

350361
collection_name = kwargs["collection_name"]
351362
documents = kwargs["documents"]
363+
batch_size = kwargs.get("batch_size", self.default_batch_size)
352364

353365
if not documents:
354366
raise ValueError("Documents list cannot be empty")
@@ -358,13 +370,17 @@ async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> Non
358370
embedding_function=self.embedding_function,
359371
)
360372
prepared = _prepare_documents_for_chromadb(documents)
361-
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
362-
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
363-
await collection.upsert(
364-
ids=prepared.ids,
365-
documents=prepared.texts,
366-
metadatas=metadatas,
367-
)
373+
374+
for i in range(0, len(prepared.ids), batch_size):
375+
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
376+
prepared=prepared, start_index=i, batch_size=batch_size
377+
)
378+
379+
await collection.upsert(
380+
ids=batch_ids,
381+
documents=batch_texts,
382+
metadatas=batch_metadatas,
383+
)
368384

369385
def search(
370386
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]

src/crewai/rag/chromadb/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
4141
embedding_function=config.embedding_function,
4242
default_limit=config.limit,
4343
default_score_threshold=config.score_threshold,
44+
default_batch_size=config.batch_size,
4445
)

src/crewai/rag/chromadb/utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Utility functions for ChromaDB client implementation."""
22

33
import hashlib
4+
import json
45
from collections.abc import Mapping
56
from typing import Literal, TypeGuard, cast
67

@@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb(
7273
if "doc_id" in doc:
7374
ids.append(doc["doc_id"])
7475
else:
75-
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
76+
content_for_hash = doc["content"]
77+
metadata = doc.get("metadata")
78+
if metadata:
79+
metadata_str = json.dumps(metadata, sort_keys=True)
80+
content_for_hash = f"{content_for_hash}|{metadata_str}"
81+
82+
content_hash = hashlib.blake2b(
83+
content_for_hash.encode(), digest_size=32
84+
).hexdigest()
7685
ids.append(content_hash)
7786

7887
texts.append(doc["content"])
@@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb(
8897
return PreparedDocuments(ids, texts, metadatas)
8998

9099

100+
def _create_batch_slice(
    prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
    """Cut one batch of ids, texts, and metadatas out of prepared documents.

    Args:
        prepared: PreparedDocuments whose ``ids``, ``texts`` and ``metadatas``
            sequences are sliced with the same index window.
        start_index: Index of the first document to include in the batch.
        batch_size: Maximum number of documents to include in the batch.

    Returns:
        Tuple of (batch_ids, batch_texts, batch_metadatas). The metadata
        element is None when ``prepared.metadatas`` is empty, or when the
        batch is non-empty and none of its metadata entries is truthy.
    """
    # Never read past the last prepared id; the final batch may be shorter.
    stop_index = min(start_index + batch_size, len(prepared.ids))
    window = slice(start_index, stop_index)

    chunk_ids = prepared.ids[window]
    chunk_texts = prepared.texts[window]

    if not prepared.metadatas:
        return chunk_ids, chunk_texts, None

    chunk_metadatas = prepared.metadatas[window]
    # ChromaDB doesn't accept empty metadata dicts, so hand back None when
    # every entry in this batch is empty.
    every_entry_empty = bool(chunk_metadatas) and not any(chunk_metadatas)
    if every_entry_empty:
        return chunk_ids, chunk_texts, None

    return chunk_ids, chunk_texts, chunk_metadatas
124+
125+
91126
def _extract_search_params(
92127
kwargs: ChromaDBCollectionSearchParams,
93128
) -> ExtractedSearchParams:

src/crewai/rag/config/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ class BaseRagConfig:
1616
embedding_function: Any | None = field(default=None)
1717
limit: int = field(default=5)
1818
score_threshold: float = field(default=0.6)
19+
batch_size: int = field(default=100)

src/crewai/rag/core/base_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@ class BaseCollectionParams(TypedDict):
2929
]
3030

3131

32-
class BaseCollectionAddParams(BaseCollectionParams):
32+
class BaseCollectionAddParams(BaseCollectionParams, total=False):
3333
"""Parameters for adding documents to a collection.
3434
3535
Extends BaseCollectionParams with document-specific fields.
3636
3737
Attributes:
3838
collection_name: The name of the collection to add documents to.
3939
documents: List of BaseRecord dictionaries containing document data.
40+
batch_size: Optional batch size for processing documents to avoid token limits.
4041
"""
4142

42-
documents: list[BaseRecord]
43+
documents: Required[list[BaseRecord]]
44+
batch_size: int
4345

4446

4547
class BaseCollectionSearchParams(BaseCollectionParams, total=False):

src/crewai/rag/embeddings/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,6 @@ def get_embedding_function(
244244

245245
_inject_api_key_from_env(provider, config_dict)
246246

247+
config_dict.pop("batch_size", None)
248+
247249
return EMBEDDING_PROVIDERS[provider](**config_dict)

0 commit comments

Comments
 (0)