Skip to content

Commit 2a8393f

Browse files
authored
INTPYTHON-655 Infer dimensions from embedding if not provided (#182)
1 parent dd08e84 commit 2a8393f

File tree

3 files changed

+61
-10
lines changed

3 files changed

+61
-10
lines changed

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def __init__(
208208
embedding_key: str = "embedding",
209209
relevance_score_fn: str = "cosine",
210210
dimensions: int = -1,
211-
auto_create_index: bool = True,
211+
auto_create_index: bool | None = None,
212212
auto_index_timeout: int = 15,
213213
**kwargs: Any,
214214
):
@@ -222,8 +222,9 @@ def __init__(
222222
embedding_key: Field that will contain the embedding for each document
223223
relevance_score_fn: The similarity score used for the index
224224
Currently supported: 'euclidean', 'cosine', and 'dotProduct'
225-
dimensions: Number of dimensions in embedding. If the value is set and
226-
the index does not exist, an index will be created.
225+
auto_create_index: Whether to automatically create an index if it does not exist.
226+
dimensions: Number of dimensions in embedding. If the value is not provided, and `auto_create_index`
227+
is `true`, the value will be inferred.
227228
auto_index_timeout: Timeout in seconds to wait for an auto-created index
228229
to be ready.
229230
"""
@@ -234,18 +235,21 @@ def __init__(
234235
self._embedding_key = embedding_key
235236
self._relevance_score_fn = relevance_score_fn
236237

237-
if not auto_create_index or dimensions == -1:
238+
if auto_create_index is False:
238239
return
240+
if auto_create_index is None and dimensions == -1:
241+
return
242+
if dimensions == -1:
243+
dimensions = len(embedding.embed_query("foo"))
244+
239245
coll = self._collection
240-
if not any(
241-
[ix["name"] == self._index_name for ix in coll.list_search_indexes()]
242-
):
246+
if not any([ix["name"] == index_name for ix in coll.list_search_indexes()]):
243247
create_vector_search_index(
244248
collection=coll,
245-
index_name=self._index_name,
249+
index_name=index_name,
246250
dimensions=dimensions,
247-
path=self._embedding_key,
248-
similarity=self._relevance_score_fn,
251+
path=embedding_key,
252+
similarity=relevance_score_fn,
249253
wait_until_complete=auto_index_timeout,
250254
)
251255

libs/langchain-mongodb/tests/unit_tests/test_vectorstores.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def test_from_documents(
9898
collection=collection,
9999
index_name=INDEX_NAME,
100100
)
101+
# TODO: test how DIMS is handled here.
101102
self._validate_search(
102103
vectorstore, collection, metadata=documents[2].metadata["c"]
103104
)
@@ -191,3 +192,35 @@ def test_mmr(
191192
assert len(output) == len(texts)
192193
assert output[0].page_content == "foo"
193194
assert output[1].page_content != "foo"
195+
196+
def test_auto_create_index(
197+
self, embedding_openai: Embeddings, collection: MockCollection
198+
) -> None:
199+
# Explicit auto_create_index
200+
assert len(collection._search_indexes) == 0
201+
_ = MongoDBAtlasVectorSearch(
202+
embedding=embedding_openai,
203+
collection=collection,
204+
index_name=INDEX_NAME,
205+
auto_create_index=True,
206+
)
207+
assert len(collection._search_indexes) == 1
208+
209+
# Explicit dimensions
210+
collection._search_indexes = []
211+
_ = MongoDBAtlasVectorSearch(
212+
embedding=embedding_openai,
213+
collection=collection,
214+
index_name=INDEX_NAME,
215+
dimensions=10,
216+
)
217+
assert len(collection._search_indexes) == 1
218+
219+
# Does not auto-create
220+
collection._search_indexes = []
221+
_ = MongoDBAtlasVectorSearch(
222+
embedding=embedding_openai,
223+
collection=collection,
224+
index_name=INDEX_NAME,
225+
)
226+
assert len(collection._search_indexes) == 0

libs/langchain-mongodb/tests/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pydantic import model_validator
2424
from pymongo import MongoClient
2525
from pymongo.collection import Collection
26+
from pymongo.operations import SearchIndexModel
2627
from pymongo.results import BulkWriteResult, DeleteResult, InsertManyResult
2728

2829
from langchain_mongodb import MongoDBAtlasVectorSearch
@@ -250,9 +251,11 @@ class MockCollection(Collection):
250251

251252
def __init__(self, database: MockDatabase | None = None) -> None:
252253
self._data = []
254+
self._name = "test"
253255
self.is_closed = False
254256
self._aggregate_result = []
255257
self._insert_result = None
258+
self._search_indexes = []
256259
self._simulate_cache_aggregation_query = False
257260
self._database = database or MockDatabase() # type:ignore[assignment]
258261

@@ -263,6 +266,17 @@ def database(self):
263266
def close(self):
264267
self.is_closed = True
265268

269+
def list_search_indexes(self, name=None, session=None, comment=None, **kwargs):
270+
return [
271+
dict(name=idx.document["name"], status="READY")
272+
for idx in self._search_indexes
273+
]
274+
275+
def create_search_index(self, model, session=None, comment=None, **kwargs):
276+
if not isinstance(model, SearchIndexModel):
277+
model = SearchIndexModel(model, name=f"test{len(self._search_indexes)}")
278+
self._search_indexes.append(model)
279+
266280
def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore
267281
old_len = len(self._data)
268282
self._data = []

0 commit comments

Comments
 (0)