Skip to content

Commit 4ac65eb

Browse files
fix: support nested config format for embedder configuration
- Support nested config format with the `EmbedderConfig` TypedDict
- Fix parsing for `model`/`model_name` compatibility
- Add validation, `typing_extensions`, and improved type hints
- Enhance the embedding factory with env var injection and provider support
- Add tests for OpenAI, Azure, and all embedding providers
- Misc fixes: test file rename, updated mocking patterns
1 parent 3e97393 commit 4ac65eb

File tree

7 files changed: 926 additions and 299 deletions.

src/crewai/memory/storage/rag_storage.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,8 @@
77
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
88
from crewai.rag.config.utils import get_rag_client
99
from crewai.rag.core.base_client import BaseClient
10-
from crewai.rag.embeddings.factory import get_embedding_function
10+
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
11+
from crewai.rag.embeddings.types import EmbeddingOptions
1112
from crewai.rag.factory import create_client
1213
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
1314
from crewai.rag.types import BaseRecord
@@ -25,7 +26,7 @@ def __init__(
2526
self,
2627
type: str,
2728
allow_reset: bool = True,
28-
embedder_config: dict[str, Any] | None = None,
29+
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
2930
crew: Any = None,
3031
path: str | None = None,
3132
) -> None:
@@ -50,6 +51,21 @@ def __init__(
5051

5152
if self.embedder_config:
5253
embedding_function = get_embedding_function(self.embedder_config)
54+
55+
try:
56+
_ = embedding_function(["test"])
57+
except Exception as e:
58+
provider = (
59+
self.embedder_config.provider
60+
if isinstance(self.embedder_config, EmbeddingOptions)
61+
else self.embedder_config.get("provider", "unknown")
62+
)
63+
raise ValueError(
64+
f"Failed to initialize embedder. Please check your configuration or connection.\n"
65+
f"Provider: {provider}\n"
66+
f"Error: {e}"
67+
) from e
68+
5369
config = ChromaDBConfig(
5470
embedding_function=cast(
5571
ChromaEmbeddingFunctionWrapper, embedding_function

src/crewai/rag/embeddings/factory.py

Lines changed: 145 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
"""Minimal embedding function factory for CrewAI."""
22

33
import os
4+
from collections.abc import Callable, MutableMapping
5+
from typing import Any, Final, Literal, TypedDict
46

57
from chromadb import EmbeddingFunction
68
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
@@ -42,19 +44,116 @@
4244
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
4345
Text2VecEmbeddingFunction,
4446
)
47+
from typing_extensions import NotRequired
4548

4649
from crewai.rag.embeddings.types import EmbeddingOptions
4750

51+
AllowedEmbeddingProviders = Literal[
52+
"openai",
53+
"cohere",
54+
"ollama",
55+
"huggingface",
56+
"sentence-transformer",
57+
"instructor",
58+
"google-palm",
59+
"google-generativeai",
60+
"google-vertex",
61+
"amazon-bedrock",
62+
"jina",
63+
"roboflow",
64+
"openclip",
65+
"text2vec",
66+
"onnx",
67+
]
68+
69+
70+
class EmbedderConfig(TypedDict):
71+
"""Configuration for embedding functions with nested format."""
72+
73+
provider: AllowedEmbeddingProviders
74+
config: NotRequired[dict[str, Any]]
75+
76+
77+
EMBEDDING_PROVIDERS: Final[
78+
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
79+
] = {
80+
"openai": OpenAIEmbeddingFunction,
81+
"cohere": CohereEmbeddingFunction,
82+
"ollama": OllamaEmbeddingFunction,
83+
"huggingface": HuggingFaceEmbeddingFunction,
84+
"sentence-transformer": SentenceTransformerEmbeddingFunction,
85+
"instructor": InstructorEmbeddingFunction,
86+
"google-palm": GooglePalmEmbeddingFunction,
87+
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
88+
"google-vertex": GoogleVertexEmbeddingFunction,
89+
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
90+
"jina": JinaEmbeddingFunction,
91+
"roboflow": RoboflowEmbeddingFunction,
92+
"openclip": OpenCLIPEmbeddingFunction,
93+
"text2vec": Text2VecEmbeddingFunction,
94+
"onnx": ONNXMiniLM_L6_V2,
95+
}
96+
97+
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
98+
"openai": ("OPENAI_API_KEY", "api_key"),
99+
"cohere": ("COHERE_API_KEY", "api_key"),
100+
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
101+
"google-palm": ("GOOGLE_API_KEY", "api_key"),
102+
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
103+
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
104+
"jina": ("JINA_API_KEY", "api_key"),
105+
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
106+
}
107+
108+
109+
def _inject_api_key_from_env(
110+
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
111+
) -> None:
112+
"""Inject API key or other required configuration from environment if not explicitly provided.
113+
114+
Args:
115+
provider: The embedding provider name
116+
config_dict: The configuration dictionary to modify in-place
117+
118+
Raises:
119+
ImportError: If required libraries for certain providers are not installed
120+
ValueError: If AWS session creation fails for amazon-bedrock
121+
"""
122+
if provider in PROVIDER_ENV_MAPPING:
123+
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
124+
if config_key not in config_dict:
125+
env_value = os.getenv(env_var_name)
126+
if env_value:
127+
config_dict[config_key] = env_value
128+
129+
if provider == "amazon-bedrock":
130+
if "session" not in config_dict:
131+
try:
132+
import boto3 # type: ignore[import]
133+
134+
config_dict["session"] = boto3.Session()
135+
except ImportError as e:
136+
raise ImportError(
137+
"boto3 is required for amazon-bedrock embeddings. "
138+
"Install it with: uv add boto3"
139+
) from e
140+
except Exception as e:
141+
raise ValueError(
142+
f"Failed to create AWS session for amazon-bedrock. "
143+
f"Ensure AWS credentials are configured. Error: {e}"
144+
) from e
145+
48146

49147
def get_embedding_function(
50-
config: EmbeddingOptions | dict | None = None,
148+
config: EmbeddingOptions | EmbedderConfig | None = None,
51149
) -> EmbeddingFunction:
52150
"""Get embedding function - delegates to ChromaDB.
53151
54152
Args:
55-
config: Optional configuration - either an EmbeddingOptions object or a dict with:
56-
- provider: The embedding provider to use (default: "openai")
57-
- Any other provider-specific parameters
153+
config: Optional configuration - either:
154+
- EmbeddingOptions: Pydantic model with flat configuration
155+
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
156+
- None: Uses default OpenAI configuration
58157
59158
Returns:
60159
EmbeddingFunction instance ready for use with ChromaDB
@@ -81,31 +180,33 @@ def get_embedding_function(
81180
>>> embedder = get_embedding_function()
82181
83182
# Use Cohere with dict
84-
>>> embedder = get_embedding_function({
183+
>>> embedder = get_embedding_function(EmbedderConfig(**{
85184
... "provider": "cohere",
86-
... "api_key": "your-key",
87-
... "model_name": "embed-english-v3.0"
88-
... })
185+
... "config": {
186+
... "api_key": "your-key",
187+
... "model_name": "embed-english-v3.0"
188+
... }
189+
... }))
89190
90191
# Use with EmbeddingOptions
91192
>>> embedder = get_embedding_function(
92193
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
93194
... )
94195
95-
# Use local sentence transformers (no API key needed)
96-
>>> embedder = get_embedding_function({
97-
... "provider": "sentence-transformer",
98-
... "model_name": "all-MiniLM-L6-v2"
99-
... })
100-
101-
# Use Ollama for local embeddings
102-
>>> embedder = get_embedding_function({
103-
... "provider": "ollama",
104-
... "model_name": "nomic-embed-text"
196+
# Use Azure OpenAI
197+
>>> embedder = get_embedding_function(EmbedderConfig(**{
198+
... "provider": "openai",
199+
... "config": {
200+
... "api_key": "your-azure-key",
201+
... "api_base": "https://your-resource.openai.azure.com/",
202+
... "api_type": "azure",
203+
... "api_version": "2023-05-15",
204+
... "model": "text-embedding-3-small",
205+
... "deployment_id": "your-deployment-name"
206+
... }
105207
... }))
106208
107-
# Use ONNX (no API key needed)
108-
>>> embedder = get_embedding_function({
209+
>>> embedder = get_embedding_function(EmbedderConfig(**{
109210
... "provider": "onnx"
110211
... }))
111212
"""
@@ -114,35 +215,33 @@ def get_embedding_function(
114215
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
115216
)
116217

117-
# Handle EmbeddingOptions object
218+
provider: AllowedEmbeddingProviders
219+
config_dict: dict[str, Any]
220+
118221
if isinstance(config, EmbeddingOptions):
119222
config_dict = config.model_dump(exclude_none=True)
223+
provider = config_dict["provider"]
120224
else:
121-
config_dict = config.copy()
122-
123-
provider = config_dict.pop("provider", "openai")
124-
125-
embedding_functions = {
126-
"openai": OpenAIEmbeddingFunction,
127-
"cohere": CohereEmbeddingFunction,
128-
"ollama": OllamaEmbeddingFunction,
129-
"huggingface": HuggingFaceEmbeddingFunction,
130-
"sentence-transformer": SentenceTransformerEmbeddingFunction,
131-
"instructor": InstructorEmbeddingFunction,
132-
"google-palm": GooglePalmEmbeddingFunction,
133-
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
134-
"google-vertex": GoogleVertexEmbeddingFunction,
135-
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
136-
"jina": JinaEmbeddingFunction,
137-
"roboflow": RoboflowEmbeddingFunction,
138-
"openclip": OpenCLIPEmbeddingFunction,
139-
"text2vec": Text2VecEmbeddingFunction,
140-
"onnx": ONNXMiniLM_L6_V2,
141-
}
142-
143-
if provider not in embedding_functions:
225+
provider = config["provider"]
226+
nested: dict[str, Any] = config.get("config", {})
227+
228+
if not nested and len(config) > 1:
229+
raise ValueError(
230+
"Invalid embedder configuration format. "
231+
"Configuration must be nested under a 'config' key. "
232+
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
233+
)
234+
235+
config_dict = dict(nested)
236+
if "model" in config_dict and "model_name" not in config_dict:
237+
config_dict["model_name"] = config_dict.pop("model")
238+
239+
if provider not in EMBEDDING_PROVIDERS:
144240
raise ValueError(
145241
f"Unsupported provider: {provider}. "
146-
f"Available providers: {list(embedding_functions.keys())}"
242+
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
147243
)
148-
return embedding_functions[provider](**config_dict)
244+
245+
_inject_api_key_from_env(provider, config_dict)
246+
247+
return EMBEDDING_PROVIDERS[provider](**config_dict)

src/crewai/rag/storage/base_rag_storage.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,9 @@
11
from abc import ABC, abstractmethod
22
from typing import Any
33

4+
from crewai.rag.embeddings.factory import EmbedderConfig
5+
from crewai.rag.embeddings.types import EmbeddingOptions
6+
47

58
class BaseRAGStorage(ABC):
69
"""
@@ -13,7 +16,7 @@ def __init__(
1316
self,
1417
type: str,
1518
allow_reset: bool = True,
16-
embedder_config: dict[str, Any] | None = None,
19+
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
1720
crew: Any = None,
1821
):
1922
self.type = type

0 commit comments

Comments
 (0)