|
| 1 | +from time import sleep, time |
| 2 | +from typing import Generator, List |
| 3 | + |
| 4 | +import pytest |
| 5 | +from langchain_core.documents import Document |
| 6 | +from langchain_core.embeddings import Embeddings |
| 7 | +from pymongo import MongoClient |
| 8 | +from pymongo.collection import Collection |
| 9 | + |
| 10 | +from langchain_mongodb import MongoDBAtlasVectorSearch |
| 11 | +from langchain_mongodb.index import ( |
| 12 | + create_fulltext_search_index, |
| 13 | + create_vector_search_index, |
| 14 | +) |
| 15 | +from langchain_mongodb.retrievers import ( |
| 16 | + MongoDBAtlasFullTextSearchRetriever, |
| 17 | + MongoDBAtlasHybridSearchRetriever, |
| 18 | +) |
| 19 | + |
| 20 | +from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch |
| 21 | + |
| 22 | +COLLECTION_NAME = "langchain_test_retrievers" |
| 23 | +COLLECTION_NAME_NESTED = "langchain_test_retrievers_nested" |
| 24 | +VECTOR_INDEX_NAME = "vector_index" |
| 25 | +EMBEDDING_FIELD = "embedding" |
| 26 | +PAGE_CONTENT_FIELD = ["text", "keywords"] |
| 27 | +PAGE_CONTENT_FIELD_NESTED = "title.text" |
| 28 | +SEARCH_INDEX_NAME = "text_index_multi" |
| 29 | +SEARCH_INDEX_NAME_NESTED = "text_index_nested" |
| 30 | + |
| 31 | +TIMEOUT = 60.0 |
| 32 | +INTERVAL = 0.5 |
| 33 | + |
| 34 | + |
| 35 | +@pytest.fixture(scope="module") |
| 36 | +def example_documents() -> List[Document]: |
| 37 | + return [ |
| 38 | + Document( |
| 39 | + page_content="In 2023, I visited Paris", metadata={"keywords": "MongoDB"} |
| 40 | + ), |
| 41 | + Document( |
| 42 | + page_content="In 2022, I visited New York", |
| 43 | + metadata={"keywords": "Atlas"}, |
| 44 | + ), |
| 45 | + Document( |
| 46 | + page_content="In 2021, I visited New Orleans", |
| 47 | + metadata={"keywords": "Search"}, |
| 48 | + ), |
| 49 | + Document( |
| 50 | + page_content="Sandwiches are beautiful. Sandwiches are fine.", |
| 51 | + metadata={"keywords": "is awesome"}, |
| 52 | + ), |
| 53 | + ] |
| 54 | + |
| 55 | + |
| 56 | +@pytest.fixture(scope="module") |
| 57 | +def collection(client: MongoClient, dimensions: int) -> Collection: |
| 58 | + """A Collection with both a Vector and a Full-text Search Index""" |
| 59 | + if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): |
| 60 | + clxn = client[DB_NAME].create_collection(COLLECTION_NAME) |
| 61 | + else: |
| 62 | + clxn = client[DB_NAME][COLLECTION_NAME] |
| 63 | + |
| 64 | + clxn.delete_many({}) |
| 65 | + |
| 66 | + if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): |
| 67 | + create_vector_search_index( |
| 68 | + collection=clxn, |
| 69 | + index_name=VECTOR_INDEX_NAME, |
| 70 | + dimensions=dimensions, |
| 71 | + path="embedding", |
| 72 | + similarity="cosine", |
| 73 | + wait_until_complete=TIMEOUT, |
| 74 | + ) |
| 75 | + |
| 76 | + if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): |
| 77 | + create_fulltext_search_index( |
| 78 | + collection=clxn, |
| 79 | + index_name=SEARCH_INDEX_NAME, |
| 80 | + field=PAGE_CONTENT_FIELD, |
| 81 | + wait_until_complete=TIMEOUT, |
| 82 | + ) |
| 83 | + |
| 84 | + return clxn |
| 85 | + |
| 86 | + |
| 87 | +@pytest.fixture(scope="module") |
| 88 | +def collection_nested(client: MongoClient, dimensions: int) -> Collection: |
| 89 | + """A Collection with both a Vector and a Full-text Search Index""" |
| 90 | + if COLLECTION_NAME_NESTED not in client[DB_NAME].list_collection_names(): |
| 91 | + clxn = client[DB_NAME].create_collection(COLLECTION_NAME_NESTED) |
| 92 | + else: |
| 93 | + clxn = client[DB_NAME][COLLECTION_NAME_NESTED] |
| 94 | + |
| 95 | + clxn.delete_many({}) |
| 96 | + |
| 97 | + if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): |
| 98 | + create_vector_search_index( |
| 99 | + collection=clxn, |
| 100 | + index_name=VECTOR_INDEX_NAME, |
| 101 | + dimensions=dimensions, |
| 102 | + path="embedding", |
| 103 | + similarity="cosine", |
| 104 | + wait_until_complete=TIMEOUT, |
| 105 | + ) |
| 106 | + |
| 107 | + if not any( |
| 108 | + [SEARCH_INDEX_NAME_NESTED == ix["name"] for ix in clxn.list_search_indexes()] |
| 109 | + ): |
| 110 | + create_fulltext_search_index( |
| 111 | + collection=clxn, |
| 112 | + index_name=SEARCH_INDEX_NAME_NESTED, |
| 113 | + field=PAGE_CONTENT_FIELD_NESTED, |
| 114 | + wait_until_complete=TIMEOUT, |
| 115 | + ) |
| 116 | + |
| 117 | + return clxn |
| 118 | + |
| 119 | + |
| 120 | +@pytest.fixture(scope="module") |
| 121 | +def indexed_vectorstore( |
| 122 | + collection: Collection, |
| 123 | + example_documents: List[Document], |
| 124 | + embedding: Embeddings, |
| 125 | +) -> Generator[MongoDBAtlasVectorSearch, None, None]: |
| 126 | + """Return a VectorStore with example document embeddings indexed.""" |
| 127 | + |
| 128 | + vectorstore = PatchedMongoDBAtlasVectorSearch( |
| 129 | + collection=collection, |
| 130 | + embedding=embedding, |
| 131 | + index_name=VECTOR_INDEX_NAME, |
| 132 | + text_key=PAGE_CONTENT_FIELD, |
| 133 | + ) |
| 134 | + |
| 135 | + vectorstore.add_documents(example_documents) |
| 136 | + |
| 137 | + yield vectorstore |
| 138 | + |
| 139 | + vectorstore.collection.delete_many({}) |
| 140 | + |
| 141 | + |
| 142 | +@pytest.fixture(scope="module") |
| 143 | +def indexed_nested_vectorstore( |
| 144 | + collection_nested: Collection, |
| 145 | + example_documents: List[Document], |
| 146 | + embedding: Embeddings, |
| 147 | +) -> Generator[MongoDBAtlasVectorSearch, None, None]: |
| 148 | + """Return a VectorStore with example document embeddings indexed.""" |
| 149 | + |
| 150 | + vectorstore = PatchedMongoDBAtlasVectorSearch( |
| 151 | + collection=collection_nested, |
| 152 | + embedding=embedding, |
| 153 | + index_name=VECTOR_INDEX_NAME, |
| 154 | + text_key=PAGE_CONTENT_FIELD_NESTED, |
| 155 | + ) |
| 156 | + |
| 157 | + vectorstore.add_documents(example_documents) |
| 158 | + |
| 159 | + yield vectorstore |
| 160 | + |
| 161 | + vectorstore.collection.delete_many({}) |
| 162 | + |
| 163 | + |
| 164 | +def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: |
| 165 | + """Test VectorStoreRetriever""" |
| 166 | + retriever = indexed_vectorstore.as_retriever() |
| 167 | + |
| 168 | + query1 = "When did I visit France?" |
| 169 | + results = retriever.invoke(query1) |
| 170 | + assert len(results) == 4 |
| 171 | + assert "Paris" in results[0].page_content |
| 172 | + assert "MongoDB" == results[0].metadata["keywords"] |
| 173 | + |
| 174 | + query2 = "When was the last time I visited new orleans?" |
| 175 | + results = retriever.invoke(query2) |
| 176 | + assert "New Orleans" in results[0].page_content |
| 177 | + assert "Search" == results[0].metadata["keywords"] |
| 178 | + |
| 179 | + |
| 180 | +def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: |
| 181 | + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" |
| 182 | + |
| 183 | + retriever = MongoDBAtlasHybridSearchRetriever( |
| 184 | + vectorstore=indexed_vectorstore, |
| 185 | + search_index_name=SEARCH_INDEX_NAME, |
| 186 | + k=3, |
| 187 | + ) |
| 188 | + |
| 189 | + query1 = "When did I visit France?" |
| 190 | + results = retriever.invoke(query1) |
| 191 | + assert len(results) == 3 |
| 192 | + assert "Paris" in results[0].page_content |
| 193 | + |
| 194 | + query2 = "When was the last time I visited new orleans?" |
| 195 | + results = retriever.invoke(query2) |
| 196 | + assert "New Orleans" in results[0].page_content |
| 197 | + |
| 198 | + |
| 199 | +def test_hybrid_retriever_deprecated_top_k( |
| 200 | + indexed_vectorstore: PatchedMongoDBAtlasVectorSearch, |
| 201 | +) -> None: |
| 202 | + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" |
| 203 | + retriever = MongoDBAtlasHybridSearchRetriever( |
| 204 | + vectorstore=indexed_vectorstore, |
| 205 | + search_index_name=SEARCH_INDEX_NAME, |
| 206 | + top_k=3, |
| 207 | + ) |
| 208 | + |
| 209 | + query1 = "When did I visit France?" |
| 210 | + results = retriever.invoke(query1) |
| 211 | + assert len(results) == 3 |
| 212 | + assert "Paris" in results[0].page_content |
| 213 | + |
| 214 | + query2 = "When was the last time I visited new orleans?" |
| 215 | + results = retriever.invoke(query2) |
| 216 | + assert "New Orleans" in results[0].page_content |
| 217 | + |
| 218 | + |
| 219 | +def test_hybrid_retriever_nested( |
| 220 | + indexed_nested_vectorstore: PatchedMongoDBAtlasVectorSearch, |
| 221 | +) -> None: |
| 222 | + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" |
| 223 | + retriever = MongoDBAtlasHybridSearchRetriever( |
| 224 | + vectorstore=indexed_nested_vectorstore, |
| 225 | + search_index_name=SEARCH_INDEX_NAME_NESTED, |
| 226 | + k=3, |
| 227 | + ) |
| 228 | + |
| 229 | + query1 = "What did I visit France?" |
| 230 | + results = retriever.invoke(query1) |
| 231 | + assert len(results) == 3 |
| 232 | + assert "Paris" in results[0].page_content |
| 233 | + |
| 234 | + query2 = "When was the last time I visited new orleans?" |
| 235 | + results = retriever.invoke(query2) |
| 236 | + assert "New Orleans" in results[0].page_content |
| 237 | + |
| 238 | + |
| 239 | +def test_fulltext_retriever( |
| 240 | + indexed_vectorstore: PatchedMongoDBAtlasVectorSearch, |
| 241 | +) -> None: |
| 242 | + """Test result of performing fulltext search. |
| 243 | +
|
| 244 | + The Retriever is independent of the VectorStore. |
| 245 | + We use it here only to get the Collection, which we know to be indexed. |
| 246 | + """ |
| 247 | + |
| 248 | + collection: Collection = indexed_vectorstore.collection |
| 249 | + |
| 250 | + retriever = MongoDBAtlasFullTextSearchRetriever( |
| 251 | + collection=collection, |
| 252 | + search_index_name=SEARCH_INDEX_NAME, |
| 253 | + search_field=PAGE_CONTENT_FIELD, |
| 254 | + ) |
| 255 | + |
| 256 | + # Wait for the search index to complete. |
| 257 | + search_content = dict( |
| 258 | + index=SEARCH_INDEX_NAME, |
| 259 | + wildcard=dict(query="*", path=PAGE_CONTENT_FIELD, allowAnalyzedField=True), |
| 260 | + ) |
| 261 | + n_docs = collection.count_documents({}) |
| 262 | + t0 = time() |
| 263 | + while True: |
| 264 | + if (time() - t0) > TIMEOUT: |
| 265 | + raise TimeoutError( |
| 266 | + f"Search index {SEARCH_INDEX_NAME} did not complete in {TIMEOUT}" |
| 267 | + ) |
| 268 | + cursor = collection.aggregate([{"$search": search_content}]) |
| 269 | + if len(list(cursor)) == n_docs: |
| 270 | + break |
| 271 | + sleep(INTERVAL) |
| 272 | + |
| 273 | + query = "What is MongoDB" |
| 274 | + results = retriever.invoke(query) |
| 275 | + print(results) |
| 276 | + print(list(collection.list_search_indexes())) |
| 277 | + # assert "New Orleans" in results[0].page_content |
| 278 | + assert "MongoDB" in results[0].metadata["keywords"] |
| 279 | + assert "score" in results[0].metadata |
0 commit comments