11"""Minimal embedding function factory for CrewAI."""
22
33import os
4+ from collections .abc import Callable , MutableMapping
5+ from typing import Any , Final , Literal , TypedDict
46
57from chromadb import EmbeddingFunction
68from chromadb .utils .embedding_functions .amazon_bedrock_embedding_function import (
4244from chromadb .utils .embedding_functions .text2vec_embedding_function import (
4345 Text2VecEmbeddingFunction ,
4446)
47+ from typing_extensions import NotRequired
4548
4649from crewai .rag .embeddings .types import EmbeddingOptions
4750
51+ AllowedEmbeddingProviders = Literal [
52+ "openai" ,
53+ "cohere" ,
54+ "ollama" ,
55+ "huggingface" ,
56+ "sentence-transformer" ,
57+ "instructor" ,
58+ "google-palm" ,
59+ "google-generativeai" ,
60+ "google-vertex" ,
61+ "amazon-bedrock" ,
62+ "jina" ,
63+ "roboflow" ,
64+ "openclip" ,
65+ "text2vec" ,
66+ "onnx" ,
67+ ]
68+
69+
70+ class EmbedderConfig (TypedDict ):
71+ """Configuration for embedding functions with nested format."""
72+
73+ provider : AllowedEmbeddingProviders
74+ config : NotRequired [dict [str , Any ]]
75+
76+
77+ EMBEDDING_PROVIDERS : Final [
78+ dict [AllowedEmbeddingProviders , Callable [..., EmbeddingFunction ]]
79+ ] = {
80+ "openai" : OpenAIEmbeddingFunction ,
81+ "cohere" : CohereEmbeddingFunction ,
82+ "ollama" : OllamaEmbeddingFunction ,
83+ "huggingface" : HuggingFaceEmbeddingFunction ,
84+ "sentence-transformer" : SentenceTransformerEmbeddingFunction ,
85+ "instructor" : InstructorEmbeddingFunction ,
86+ "google-palm" : GooglePalmEmbeddingFunction ,
87+ "google-generativeai" : GoogleGenerativeAiEmbeddingFunction ,
88+ "google-vertex" : GoogleVertexEmbeddingFunction ,
89+ "amazon-bedrock" : AmazonBedrockEmbeddingFunction ,
90+ "jina" : JinaEmbeddingFunction ,
91+ "roboflow" : RoboflowEmbeddingFunction ,
92+ "openclip" : OpenCLIPEmbeddingFunction ,
93+ "text2vec" : Text2VecEmbeddingFunction ,
94+ "onnx" : ONNXMiniLM_L6_V2 ,
95+ }
96+
97+ PROVIDER_ENV_MAPPING : Final [dict [AllowedEmbeddingProviders , tuple [str , str ]]] = {
98+ "openai" : ("OPENAI_API_KEY" , "api_key" ),
99+ "cohere" : ("COHERE_API_KEY" , "api_key" ),
100+ "huggingface" : ("HUGGINGFACE_API_KEY" , "api_key" ),
101+ "google-palm" : ("GOOGLE_API_KEY" , "api_key" ),
102+ "google-generativeai" : ("GOOGLE_API_KEY" , "api_key" ),
103+ "google-vertex" : ("GOOGLE_API_KEY" , "api_key" ),
104+ "jina" : ("JINA_API_KEY" , "api_key" ),
105+ "roboflow" : ("ROBOFLOW_API_KEY" , "api_key" ),
106+ }
107+
108+
109+ def _inject_api_key_from_env (
110+ provider : AllowedEmbeddingProviders , config_dict : MutableMapping [str , Any ]
111+ ) -> None :
112+ """Inject API key or other required configuration from environment if not explicitly provided.
113+
114+ Args:
115+ provider: The embedding provider name
116+ config_dict: The configuration dictionary to modify in-place
117+
118+ Raises:
119+ ImportError: If required libraries for certain providers are not installed
120+ ValueError: If AWS session creation fails for amazon-bedrock
121+ """
122+ if provider in PROVIDER_ENV_MAPPING :
123+ env_var_name , config_key = PROVIDER_ENV_MAPPING [provider ]
124+ if config_key not in config_dict :
125+ env_value = os .getenv (env_var_name )
126+ if env_value :
127+ config_dict [config_key ] = env_value
128+
129+ if provider == "amazon-bedrock" :
130+ if "session" not in config_dict :
131+ try :
132+ import boto3 # type: ignore[import]
133+
134+ config_dict ["session" ] = boto3 .Session ()
135+ except ImportError as e :
136+ raise ImportError (
137+ "boto3 is required for amazon-bedrock embeddings. "
138+ "Install it with: uv add boto3"
139+ ) from e
140+ except Exception as e :
141+ raise ValueError (
142+ f"Failed to create AWS session for amazon-bedrock. "
143+ f"Ensure AWS credentials are configured. Error: { e } "
144+ ) from e
145+
48146
49147def get_embedding_function (
50- config : EmbeddingOptions | dict | None = None ,
148+ config : EmbeddingOptions | EmbedderConfig | None = None ,
51149) -> EmbeddingFunction :
52150 """Get embedding function - delegates to ChromaDB.
53151
54152 Args:
55- config: Optional configuration - either an EmbeddingOptions object or a dict with:
56- - provider: The embedding provider to use (default: "openai")
57- - Any other provider-specific parameters
153+ config: Optional configuration - either:
154+ - EmbeddingOptions: Pydantic model with flat configuration
155+ - EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
156+ - None: Uses default OpenAI configuration
58157
59158 Returns:
60159 EmbeddingFunction instance ready for use with ChromaDB
@@ -81,31 +180,33 @@ def get_embedding_function(
81180 >>> embedder = get_embedding_function()
82181
83182 # Use Cohere with dict
84- >>> embedder = get_embedding_function({
183+ >>> embedder = get_embedding_function(EmbedderConfig(** {
85184 ... "provider": "cohere",
86- ... "api_key": "your-key",
87- ... "model_name": "embed-english-v3.0"
88- ... })
185+ ... "config": {
186+ ... "api_key": "your-key",
187+ ... "model_name": "embed-english-v3.0"
188+ ... }
189+ ... }))
89190
90191 # Use with EmbeddingOptions
91192 >>> embedder = get_embedding_function(
92193 ... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
93194 ... )
94195
95- # Use local sentence transformers (no API key needed)
96- >>> embedder = get_embedding_function({
97- ... "provider": "sentence-transformer",
98- ... "model_name": "all-MiniLM-L6-v2"
99- ... })
100-
101- # Use Ollama for local embeddings
102- >>> embedder = get_embedding_function({
103- ... "provider": "ollama",
104- ... "model_name": "nomic-embed-text"
196+ # Use Azure OpenAI
197+ >>> embedder = get_embedding_function(EmbedderConfig(**{
198+ ... "provider": "openai",
199+ ... "config": {
200+ ... "api_key": "your-azure-key",
201+ ... "api_base": "https://your-resource.openai.azure.com/",
202+ ... "api_type": "azure",
203+ ... "api_version": "2023-05-15",
204+ ... "model": "text-embedding-3-small",
205+ ... "deployment_id": "your-deployment-name"
206+ ... }
105207 ... })
106208
107- # Use ONNX (no API key needed)
108- >>> embedder = get_embedding_function({
209+ >>> embedder = get_embedding_function(EmbedderConfig(**{
109210 ... "provider": "onnx"
110211 ... })
111212 """
@@ -114,35 +215,33 @@ def get_embedding_function(
114215 api_key = os .getenv ("OPENAI_API_KEY" ), model_name = "text-embedding-3-small"
115216 )
116217
117- # Handle EmbeddingOptions object
218+ provider : AllowedEmbeddingProviders
219+ config_dict : dict [str , Any ]
220+
118221 if isinstance (config , EmbeddingOptions ):
119222 config_dict = config .model_dump (exclude_none = True )
223+ provider = config_dict ["provider" ]
120224 else :
121- config_dict = config .copy ()
122-
123- provider = config_dict .pop ("provider" , "openai" )
124-
125- embedding_functions = {
126- "openai" : OpenAIEmbeddingFunction ,
127- "cohere" : CohereEmbeddingFunction ,
128- "ollama" : OllamaEmbeddingFunction ,
129- "huggingface" : HuggingFaceEmbeddingFunction ,
130- "sentence-transformer" : SentenceTransformerEmbeddingFunction ,
131- "instructor" : InstructorEmbeddingFunction ,
132- "google-palm" : GooglePalmEmbeddingFunction ,
133- "google-generativeai" : GoogleGenerativeAiEmbeddingFunction ,
134- "google-vertex" : GoogleVertexEmbeddingFunction ,
135- "amazon-bedrock" : AmazonBedrockEmbeddingFunction ,
136- "jina" : JinaEmbeddingFunction ,
137- "roboflow" : RoboflowEmbeddingFunction ,
138- "openclip" : OpenCLIPEmbeddingFunction ,
139- "text2vec" : Text2VecEmbeddingFunction ,
140- "onnx" : ONNXMiniLM_L6_V2 ,
141- }
142-
143- if provider not in embedding_functions :
225+ provider = config ["provider" ]
226+ nested : dict [str , Any ] = config .get ("config" , {})
227+
228+ if not nested and len (config ) > 1 :
229+ raise ValueError (
230+ "Invalid embedder configuration format. "
231+ "Configuration must be nested under a 'config' key. "
232+ "Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
233+ )
234+
235+ config_dict = dict (nested )
236+ if "model" in config_dict and "model_name" not in config_dict :
237+ config_dict ["model_name" ] = config_dict .pop ("model" )
238+
239+ if provider not in EMBEDDING_PROVIDERS :
144240 raise ValueError (
145241 f"Unsupported provider: { provider } . "
146- f"Available providers: { list (embedding_functions .keys ())} "
242+ f"Available providers: { list (EMBEDDING_PROVIDERS .keys ())} "
147243 )
148- return embedding_functions [provider ](** config_dict )
244+
245+ _inject_api_key_from_env (provider , config_dict )
246+
247+ return EMBEDDING_PROVIDERS [provider ](** config_dict )
0 commit comments