Skip to content

Commit 4876051

Browse files
authored
Add multi_field support for $search operation (#166)
Linked to #165. All tests checked.
1 parent 37aa82d commit 4876051

File tree

5 files changed

+299
-13
lines changed

5 files changed

+299
-13
lines changed

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from time import monotonic, sleep
5-
from typing import Any, Callable, Dict, List, Optional
5+
from typing import Any, Callable, Dict, List, Optional, Union
66

77
from pymongo.collection import Collection
88
from pymongo.operations import SearchIndexModel
@@ -202,7 +202,7 @@ def _wait_for_predicate(
202202
def create_fulltext_search_index(
203203
collection: Collection,
204204
index_name: str,
205-
field: str,
205+
field: Union[str, List[str]],
206206
*,
207207
wait_until_complete: Optional[float] = None,
208208
**kwargs: Any,
@@ -222,9 +222,11 @@ def create_fulltext_search_index(
222222
if collection.name not in collection.database.list_collection_names():
223223
collection.database.create_collection(collection.name)
224224

225-
definition = {
226-
"mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
227-
}
225+
if isinstance(field, str):
226+
fields_definition = {field: [{"type": "string"}]}
227+
else:
228+
fields_definition = {f: [{"type": "string"}] for f in field}
229+
definition = {"mappings": {"dynamic": False, "fields": fields_definition}}
228230
result = collection.create_search_index(
229231
SearchIndexModel(
230232
definition=definition,

libs/langchain-mongodb/langchain_mongodb/pipelines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
88
"""
99

10-
from typing import Any, Dict, List, Optional
10+
from typing import Any, Dict, List, Optional, Union
1111

1212

1313
def text_search_stage(
1414
query: str,
15-
search_field: str,
15+
search_field: Union[str, List[str]],
1616
index_name: str,
1717
limit: Optional[int] = None,
1818
filter: Optional[Dict[str, Any]] = None,

libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated, Any, Dict, List, Optional
1+
from typing import Annotated, Any, Dict, List, Optional, Union
22

33
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
44
from langchain_core.documents import Document
@@ -17,7 +17,7 @@ class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
1717
"""MongoDB Collection on an Atlas cluster"""
1818
search_index_name: str
1919
"""Atlas Search Index name"""
20-
search_field: str
20+
search_field: Union[str, List[str]]
2121
"""Collection field that contains the text to be searched. It must be indexed"""
2222
k: Optional[int] = None
2323
"""Number of documents to return. Default is no limit"""
@@ -61,7 +61,11 @@ def _get_relevant_documents(
6161
# Formatting
6262
docs = []
6363
for res in cursor:
64-
text = res.pop(self.search_field)
64+
text = (
65+
res.pop(self.search_field)
66+
if isinstance(self.search_field, str)
67+
else res.pop(self.search_field[0])
68+
)
6569
make_serializable(res)
6670
docs.append(Document(page_content=text, metadata=res))
6771
return docs

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
collection: Collection[Dict[str, Any]],
205205
embedding: Embeddings,
206206
index_name: str = "vector_index",
207-
text_key: str = "text",
207+
text_key: Union[str, List[str]] = "text",
208208
embedding_key: str = "embedding",
209209
relevance_score_fn: str = "cosine",
210210
dimensions: int = -1,
@@ -216,7 +216,8 @@ def __init__(
216216
Args:
217217
collection: MongoDB collection to add the texts to
218218
embedding: Text embedding model to use
219-
text_key: MongoDB field that will contain the text for each document
219+
text_key: MongoDB field that will contain the text for each document. A list of fields may also be passed.\
220+
The first one will be used as text key. Default: 'text'
220221
index_name: Existing Atlas Vector Search Index
221222
embedding_key: Field that will contain the embedding for each document
222223
relevance_score_fn: The similarity score used for the index
@@ -229,7 +230,7 @@ def __init__(
229230
self._collection = collection
230231
self._embedding = embedding
231232
self._index_name = index_name
232-
self._text_key = text_key
233+
self._text_key = text_key if isinstance(text_key, str) else text_key[0]
233234
self._embedding_key = embedding_key
234235
self._relevance_score_fn = relevance_score_fn
235236

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
from time import sleep, time
2+
from typing import Generator, List
3+
4+
import pytest
5+
from langchain_core.documents import Document
6+
from langchain_core.embeddings import Embeddings
7+
from pymongo import MongoClient
8+
from pymongo.collection import Collection
9+
10+
from langchain_mongodb import MongoDBAtlasVectorSearch
11+
from langchain_mongodb.index import (
12+
create_fulltext_search_index,
13+
create_vector_search_index,
14+
)
15+
from langchain_mongodb.retrievers import (
16+
MongoDBAtlasFullTextSearchRetriever,
17+
MongoDBAtlasHybridSearchRetriever,
18+
)
19+
20+
from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch
21+
22+
COLLECTION_NAME = "langchain_test_retrievers"
23+
COLLECTION_NAME_NESTED = "langchain_test_retrievers_nested"
24+
VECTOR_INDEX_NAME = "vector_index"
25+
EMBEDDING_FIELD = "embedding"
26+
PAGE_CONTENT_FIELD = ["text", "keywords"]
27+
PAGE_CONTENT_FIELD_NESTED = "title.text"
28+
SEARCH_INDEX_NAME = "text_index_multi"
29+
SEARCH_INDEX_NAME_NESTED = "text_index_nested"
30+
31+
TIMEOUT = 60.0
32+
INTERVAL = 0.5
33+
34+
35+
@pytest.fixture(scope="module")
36+
def example_documents() -> List[Document]:
37+
return [
38+
Document(
39+
page_content="In 2023, I visited Paris", metadata={"keywords": "MongoDB"}
40+
),
41+
Document(
42+
page_content="In 2022, I visited New York",
43+
metadata={"keywords": "Atlas"},
44+
),
45+
Document(
46+
page_content="In 2021, I visited New Orleans",
47+
metadata={"keywords": "Search"},
48+
),
49+
Document(
50+
page_content="Sandwiches are beautiful. Sandwiches are fine.",
51+
metadata={"keywords": "is awesome"},
52+
),
53+
]
54+
55+
56+
@pytest.fixture(scope="module")
57+
def collection(client: MongoClient, dimensions: int) -> Collection:
58+
"""A Collection with both a Vector and a Full-text Search Index"""
59+
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
60+
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
61+
else:
62+
clxn = client[DB_NAME][COLLECTION_NAME]
63+
64+
clxn.delete_many({})
65+
66+
if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
67+
create_vector_search_index(
68+
collection=clxn,
69+
index_name=VECTOR_INDEX_NAME,
70+
dimensions=dimensions,
71+
path="embedding",
72+
similarity="cosine",
73+
wait_until_complete=TIMEOUT,
74+
)
75+
76+
if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
77+
create_fulltext_search_index(
78+
collection=clxn,
79+
index_name=SEARCH_INDEX_NAME,
80+
field=PAGE_CONTENT_FIELD,
81+
wait_until_complete=TIMEOUT,
82+
)
83+
84+
return clxn
85+
86+
87+
@pytest.fixture(scope="module")
88+
def collection_nested(client: MongoClient, dimensions: int) -> Collection:
89+
"""A Collection with both a Vector and a Full-text Search Index"""
90+
if COLLECTION_NAME_NESTED not in client[DB_NAME].list_collection_names():
91+
clxn = client[DB_NAME].create_collection(COLLECTION_NAME_NESTED)
92+
else:
93+
clxn = client[DB_NAME][COLLECTION_NAME_NESTED]
94+
95+
clxn.delete_many({})
96+
97+
if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
98+
create_vector_search_index(
99+
collection=clxn,
100+
index_name=VECTOR_INDEX_NAME,
101+
dimensions=dimensions,
102+
path="embedding",
103+
similarity="cosine",
104+
wait_until_complete=TIMEOUT,
105+
)
106+
107+
if not any(
108+
[SEARCH_INDEX_NAME_NESTED == ix["name"] for ix in clxn.list_search_indexes()]
109+
):
110+
create_fulltext_search_index(
111+
collection=clxn,
112+
index_name=SEARCH_INDEX_NAME_NESTED,
113+
field=PAGE_CONTENT_FIELD_NESTED,
114+
wait_until_complete=TIMEOUT,
115+
)
116+
117+
return clxn
118+
119+
120+
@pytest.fixture(scope="module")
121+
def indexed_vectorstore(
122+
collection: Collection,
123+
example_documents: List[Document],
124+
embedding: Embeddings,
125+
) -> Generator[MongoDBAtlasVectorSearch, None, None]:
126+
"""Return a VectorStore with example document embeddings indexed."""
127+
128+
vectorstore = PatchedMongoDBAtlasVectorSearch(
129+
collection=collection,
130+
embedding=embedding,
131+
index_name=VECTOR_INDEX_NAME,
132+
text_key=PAGE_CONTENT_FIELD,
133+
)
134+
135+
vectorstore.add_documents(example_documents)
136+
137+
yield vectorstore
138+
139+
vectorstore.collection.delete_many({})
140+
141+
142+
@pytest.fixture(scope="module")
143+
def indexed_nested_vectorstore(
144+
collection_nested: Collection,
145+
example_documents: List[Document],
146+
embedding: Embeddings,
147+
) -> Generator[MongoDBAtlasVectorSearch, None, None]:
148+
"""Return a VectorStore with example document embeddings indexed."""
149+
150+
vectorstore = PatchedMongoDBAtlasVectorSearch(
151+
collection=collection_nested,
152+
embedding=embedding,
153+
index_name=VECTOR_INDEX_NAME,
154+
text_key=PAGE_CONTENT_FIELD_NESTED,
155+
)
156+
157+
vectorstore.add_documents(example_documents)
158+
159+
yield vectorstore
160+
161+
vectorstore.collection.delete_many({})
162+
163+
164+
def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None:
165+
"""Test VectorStoreRetriever"""
166+
retriever = indexed_vectorstore.as_retriever()
167+
168+
query1 = "When did I visit France?"
169+
results = retriever.invoke(query1)
170+
assert len(results) == 4
171+
assert "Paris" in results[0].page_content
172+
assert "MongoDB" == results[0].metadata["keywords"]
173+
174+
query2 = "When was the last time I visited new orleans?"
175+
results = retriever.invoke(query2)
176+
assert "New Orleans" in results[0].page_content
177+
assert "Search" == results[0].metadata["keywords"]
178+
179+
180+
def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None:
181+
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""
182+
183+
retriever = MongoDBAtlasHybridSearchRetriever(
184+
vectorstore=indexed_vectorstore,
185+
search_index_name=SEARCH_INDEX_NAME,
186+
k=3,
187+
)
188+
189+
query1 = "When did I visit France?"
190+
results = retriever.invoke(query1)
191+
assert len(results) == 3
192+
assert "Paris" in results[0].page_content
193+
194+
query2 = "When was the last time I visited new orleans?"
195+
results = retriever.invoke(query2)
196+
assert "New Orleans" in results[0].page_content
197+
198+
199+
def test_hybrid_retriever_deprecated_top_k(
200+
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
201+
) -> None:
202+
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""
203+
retriever = MongoDBAtlasHybridSearchRetriever(
204+
vectorstore=indexed_vectorstore,
205+
search_index_name=SEARCH_INDEX_NAME,
206+
top_k=3,
207+
)
208+
209+
query1 = "When did I visit France?"
210+
results = retriever.invoke(query1)
211+
assert len(results) == 3
212+
assert "Paris" in results[0].page_content
213+
214+
query2 = "When was the last time I visited new orleans?"
215+
results = retriever.invoke(query2)
216+
assert "New Orleans" in results[0].page_content
217+
218+
219+
def test_hybrid_retriever_nested(
220+
indexed_nested_vectorstore: PatchedMongoDBAtlasVectorSearch,
221+
) -> None:
222+
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""
223+
retriever = MongoDBAtlasHybridSearchRetriever(
224+
vectorstore=indexed_nested_vectorstore,
225+
search_index_name=SEARCH_INDEX_NAME_NESTED,
226+
k=3,
227+
)
228+
229+
query1 = "What did I visit France?"
230+
results = retriever.invoke(query1)
231+
assert len(results) == 3
232+
assert "Paris" in results[0].page_content
233+
234+
query2 = "When was the last time I visited new orleans?"
235+
results = retriever.invoke(query2)
236+
assert "New Orleans" in results[0].page_content
237+
238+
239+
def test_fulltext_retriever(
240+
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
241+
) -> None:
242+
"""Test result of performing fulltext search.
243+
244+
The Retriever is independent of the VectorStore.
245+
We use it here only to get the Collection, which we know to be indexed.
246+
"""
247+
248+
collection: Collection = indexed_vectorstore.collection
249+
250+
retriever = MongoDBAtlasFullTextSearchRetriever(
251+
collection=collection,
252+
search_index_name=SEARCH_INDEX_NAME,
253+
search_field=PAGE_CONTENT_FIELD,
254+
)
255+
256+
# Wait for the search index to complete.
257+
search_content = dict(
258+
index=SEARCH_INDEX_NAME,
259+
wildcard=dict(query="*", path=PAGE_CONTENT_FIELD, allowAnalyzedField=True),
260+
)
261+
n_docs = collection.count_documents({})
262+
t0 = time()
263+
while True:
264+
if (time() - t0) > TIMEOUT:
265+
raise TimeoutError(
266+
f"Search index {SEARCH_INDEX_NAME} did not complete in {TIMEOUT}"
267+
)
268+
cursor = collection.aggregate([{"$search": search_content}])
269+
if len(list(cursor)) == n_docs:
270+
break
271+
sleep(INTERVAL)
272+
273+
query = "What is MongoDB"
274+
results = retriever.invoke(query)
275+
print(results)
276+
print(list(collection.list_search_indexes()))
277+
# assert "New Orleans" in results[0].page_content
278+
assert "MongoDB" in results[0].metadata["keywords"]
279+
assert "score" in results[0].metadata

0 commit comments

Comments
 (0)