Skip to content

Commit 73e932b

Browse files
fix: update embedding functions to inherit from chromadb callable
1 parent 12fa7e2 commit 73e932b

File tree

4 files changed

+25
-5
lines changed

4 files changed

+25
-5
lines changed

src/crewai/rag/core/base_embeddings_callable.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,10 @@ def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
140140
return validate_embeddings(normalized)
141141

142142
cls.__call__ = wrapped_call # type: ignore[method-assign]
143+
144+
def embed_query(self, input: D) -> Embeddings:
145+
"""
146+
Get the embeddings for a query input.
147+
This method is optional, and if not implemented, the default behavior is to call __call__.
148+
"""
149+
return self.__call__(input=input)

src/crewai/rag/embeddings/providers/ibm/embedding_callable.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
from typing import cast
44

5+
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
56
from typing_extensions import Unpack
67

7-
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
8-
from crewai.rag.core.types import Documents, Embeddings
98
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
109

1110

@@ -18,8 +17,14 @@ def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
1817
Args:
1918
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
2019
"""
20+
super().__init__(**kwargs)
2121
self._config = kwargs
2222

23+
@staticmethod
24+
def name() -> str:
25+
"""Return the name of the embedding function for ChromaDB compatibility."""
26+
return "watsonx"
27+
2328
def __call__(self, input: Documents) -> Embeddings:
2429
"""Generate embeddings for input documents.
2530

src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
from typing import cast
44

5+
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
56
from typing_extensions import Unpack
67

7-
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
8-
from crewai.rag.core.types import Documents, Embeddings
98
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
109

1110

@@ -33,6 +32,11 @@ def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
3332
timeout=kwargs.get("timeout"),
3433
)
3534

35+
@staticmethod
36+
def name() -> str:
37+
"""Return the name of the embedding function for ChromaDB compatibility."""
38+
return "voyageai"
39+
3640
def __call__(self, input: Documents) -> Embeddings:
3741
"""Generate embeddings for input documents.
3842

src/crewai/rag/embeddings/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
VertexAIProviderSpec,
1212
)
1313
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
14-
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
14+
from crewai.rag.embeddings.providers.ibm.types import (
15+
WatsonProviderSpec,
16+
WatsonXProviderSpec,
17+
)
1518
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
1619
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
1720
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
@@ -44,6 +47,7 @@
4447
| Text2VecProviderSpec
4548
| VertexAIProviderSpec
4649
| VoyageAIProviderSpec
50+
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
4751
| WatsonXProviderSpec
4852
)
4953

0 commit comments

Comments
 (0)