Skip to content

Commit 95b2f58

Browse files
committed
feat(cli): add a note to get_embedding_function error message. Fix #46
1 parent 43df44c commit 95b2f58

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/vectorcode/common.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def get_collection_name(full_path: str) -> str:
134134
return collection_id
135135

136136

137-
def get_embedding_function(configs: Config) -> chromadb.EmbeddingFunction:
137+
def get_embedding_function(configs: Config) -> chromadb.EmbeddingFunction | None:
138138
try:
139139
return getattr(embedding_functions, configs.embedding_function)(
140140
**configs.embedding_params
@@ -145,6 +145,16 @@ def get_embedding_function(configs: Config) -> chromadb.EmbeddingFunction:
145145
file=sys.stderr,
146146
)
147147
return embedding_functions.SentenceTransformerEmbeddingFunction()
148+
except Exception as e:
149+
print(
150+
f"Failed to use {configs.embedding_function} with the following error:",
151+
file=sys.stderr,
152+
)
153+
e.add_note(
154+
"\nFor errors caused by missing dependency, consult the documentation of pipx (or whatever package manager that you installed VectorCode with) for instructions to inject libraries into the virtual environment."
155+
)
156+
157+
raise
148158

149159

150160
__COLLECTION_CACHE: dict[str, AsyncCollection] = {}

tests/test_common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
from chromadb.api import AsyncClientAPI
1111
from chromadb.api.models.AsyncCollection import AsyncCollection
12+
from chromadb.utils import embedding_functions
1213

1314
from vectorcode.cli_utils import Config
1415
from vectorcode.common import (
@@ -68,6 +69,34 @@ def test_get_embedding_function():
6869
assert "SentenceTransformerEmbeddingFunction" in str(type(embedding_function))
6970

7071

72+
def test_get_embedding_function_init_exception():
73+
# Test when the embedding function exists but raises an error during initialization
74+
config = Config(
75+
embedding_function="SentenceTransformerEmbeddingFunction",
76+
embedding_params={"model_name": "non_existent_model_should_cause_error"},
77+
)
78+
79+
# Mock SentenceTransformerEmbeddingFunction.__init__ to raise a generic exception
80+
with patch.object(
81+
embedding_functions, "SentenceTransformerEmbeddingFunction", autospec=True
82+
) as mock_stef:
83+
# Simulate an error during the embedding function's __init__
84+
mock_stef.side_effect = Exception("Simulated initialization error")
85+
86+
with pytest.raises(Exception) as excinfo:
87+
get_embedding_function(config)
88+
89+
# Check if the raised exception is the one we simulated
90+
assert "Simulated initialization error" in str(excinfo.value)
91+
# Check if the additional note was added
92+
assert "For errors caused by missing dependency" in excinfo.value.__notes__[0]
93+
94+
# Verify that the constructor was called with the correct parameters
95+
mock_stef.assert_called_once_with(
96+
model_name="non_existent_model_should_cause_error"
97+
)
98+
99+
71100
@pytest.mark.asyncio
72101
async def test_try_server():
73102
# This test requires a server to be running, so it's difficult to make it truly isolated.

0 commit comments

Comments
 (0)