Skip to content

Commit 1442f3e

Browse files
devin-ai-integration[bot]João
andcommitted
fix: add Watson embedding support to factory
- Add Watson to EmbeddingProvider type definition - Implement _create_watson_embedding_function in factory.py - Add Watson to embedding_functions dictionary - Add comprehensive tests for Watson embedding functionality - Ensure proper error handling for missing IBM Watson dependencies Fixes #3582 Co-Authored-By: João <joao@crewai.com>
1 parent 3e97393 commit 1442f3e

File tree

3 files changed

+122
-0
lines changed

3 files changed

+122
-0
lines changed

src/crewai/rag/embeddings/factory.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,51 @@
4646
from crewai.rag.embeddings.types import EmbeddingOptions
4747

4848

49+
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
50+
"""Create Watson embedding function with proper error handling."""
51+
try:
52+
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
53+
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
54+
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
55+
EmbedTextParamsMetaNames as EmbedParams,
56+
)
57+
except ImportError as e:
58+
raise ImportError(
59+
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
60+
) from e
61+
62+
class WatsonEmbeddingFunction(EmbeddingFunction):
63+
def __init__(self, **kwargs):
64+
self.config = kwargs
65+
66+
def __call__(self, input):
67+
if isinstance(input, str):
68+
input = [input]
69+
70+
embed_params = {
71+
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
72+
EmbedParams.RETURN_OPTIONS: {"input_text": True},
73+
}
74+
75+
embedding = watson_models.Embeddings(
76+
model_id=self.config.get("model_name") or self.config.get("model"),
77+
params=embed_params,
78+
credentials=Credentials(
79+
api_key=self.config.get("api_key"),
80+
url=self.config.get("api_url") or self.config.get("url")
81+
),
82+
project_id=self.config.get("project_id"),
83+
)
84+
85+
try:
86+
embeddings = embedding.embed_documents(input)
87+
return embeddings
88+
except Exception as e:
89+
raise RuntimeError(f"Error during Watson embedding: {e}") from e
90+
91+
return WatsonEmbeddingFunction(**config_dict)
92+
93+
4994
def get_embedding_function(
5095
config: EmbeddingOptions | dict | None = None,
5196
) -> EmbeddingFunction:
@@ -75,6 +120,7 @@ def get_embedding_function(
75120
- openclip: OpenCLIP embeddings for multimodal tasks
76121
- text2vec: Text2Vec embeddings
77122
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
123+
- watson: IBM Watson embeddings
78124
79125
Examples:
80126
# Use default OpenAI embedding
@@ -108,6 +154,15 @@ def get_embedding_function(
108154
>>> embedder = get_embedding_function({
109155
... "provider": "onnx"
110156
... })
157+
158+
# Use Watson embeddings
159+
>>> embedder = get_embedding_function({
160+
... "provider": "watson",
161+
... "api_key": "your-watson-api-key",
162+
... "api_url": "your-watson-url",
163+
... "project_id": "your-project-id",
164+
... "model_name": "ibm/slate-125m-english-rtrvr"
165+
... })
111166
"""
112167
if config is None:
113168
return OpenAIEmbeddingFunction(
@@ -138,6 +193,7 @@ def get_embedding_function(
138193
"openclip": OpenCLIPEmbeddingFunction,
139194
"text2vec": Text2VecEmbeddingFunction,
140195
"onnx": ONNXMiniLM_L6_V2,
196+
"watson": _create_watson_embedding_function,
141197
}
142198

143199
if provider not in embedding_functions:

src/crewai/rag/embeddings/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"openclip",
2323
"text2vec",
2424
"onnx",
25+
"watson",
2526
]
2627
"""Supported embedding providers.
2728

tests/rag/embeddings/test_factory_enhanced.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,68 @@ def test_get_embedding_function_instructor() -> None:
248248

249249
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
250250
assert result == mock_instance
251+
252+
253+
def test_get_embedding_function_watson() -> None:
254+
"""Test Watson embedding function."""
255+
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
256+
mock_instance = MagicMock()
257+
mock_watson.return_value = mock_instance
258+
259+
config = {
260+
"provider": "watson",
261+
"api_key": "watson-api-key",
262+
"api_url": "https://watson-url.com",
263+
"project_id": "watson-project-id",
264+
"model_name": "ibm/slate-125m-english-rtrvr",
265+
}
266+
267+
result = get_embedding_function(config)
268+
269+
mock_watson.assert_called_once_with(
270+
api_key="watson-api-key",
271+
api_url="https://watson-url.com",
272+
project_id="watson-project-id",
273+
model_name="ibm/slate-125m-english-rtrvr",
274+
)
275+
assert result == mock_instance
276+
277+
278+
def test_get_embedding_function_watson_missing_dependencies() -> None:
279+
"""Test Watson embedding function with missing dependencies."""
280+
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
281+
mock_watson.side_effect = ImportError(
282+
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
283+
)
284+
285+
config = {
286+
"provider": "watson",
287+
"api_key": "watson-api-key",
288+
"api_url": "https://watson-url.com",
289+
"project_id": "watson-project-id",
290+
"model_name": "ibm/slate-125m-english-rtrvr",
291+
}
292+
293+
with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
294+
get_embedding_function(config)
295+
296+
297+
def test_get_embedding_function_watson_with_embedding_options() -> None:
298+
"""Test Watson embedding function with EmbeddingOptions object."""
299+
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
300+
mock_instance = MagicMock()
301+
mock_watson.return_value = mock_instance
302+
303+
options = EmbeddingOptions(
304+
provider="watson",
305+
api_key="watson-key",
306+
model_name="ibm/slate-125m-english-rtrvr"
307+
)
308+
309+
result = get_embedding_function(options)
310+
311+
call_kwargs = mock_watson.call_args.kwargs
312+
assert "api_key" in call_kwargs
313+
assert call_kwargs["api_key"].get_secret_value() == "watson-key"
314+
assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr"
315+
assert result == mock_instance

0 commit comments

Comments
 (0)