11import os
22
33import pytest
4+ from flaky import flaky
45from langchain_core .documents import Document
56from langchain_core .language_models .chat_models import BaseChatModel
67from langchain_core .messages import AIMessage
@@ -117,7 +118,10 @@ def entity_example():
117118@pytest .fixture (scope = "module" )
118119def graph_store (collection , entity_extraction_model , documents ) -> MongoDBGraphStore :
119120 store = MongoDBGraphStore (
120- collection , entity_extraction_model , entity_prompt , query_prompt
121+ collection = collection ,
122+ entity_extraction_model = entity_extraction_model ,
123+ entity_prompt = entity_prompt ,
124+ query_prompt = query_prompt ,
121125 )
122126 bulkwrite_results = store .add_documents (documents )
123127 assert len (bulkwrite_results ) == len (documents )
@@ -136,6 +140,7 @@ def test_add_docs_store(graph_store):
136140 assert 4 <= len (extracted_entities ) < 8
137141
138142
143+ @flaky
139144def test_extract_entity_names (graph_store , query_connection ):
140145 query_entity_names = graph_store .extract_entity_names (query_connection )
141146 assert set (query_entity_names ) == {"John Doe" , "Jane Smith" }
@@ -145,6 +150,7 @@ def test_extract_entity_names(graph_store, query_connection):
145150 assert len (no_names ) == 0
146151
147152
153+ @flaky
148154def test_related_entities (graph_store ):
149155 entity_names = ["John Doe" , "Jane Smith" ]
150156 related_entities = graph_store .related_entities (entity_names )
@@ -155,29 +161,39 @@ def test_related_entities(graph_store):
155161 assert len (no_entities ) == 0
156162
157163
@flaky
def test_additional_entity_examples(entity_extraction_model, entity_example, documents):
    """Build a graph store configured with additional entity examples.

    The store is created from a connection string (rather than an existing
    collection), populated with ``documents``, and then queried for entities
    related to two organisation names. At least two entities must come back.
    """
    clxn_name = f"{COLLECTION_NAME}_addl_examples"

    # A short-lived client is needed only to clear leftovers from earlier
    # runs. The context manager closes it even if the drop raises, so
    # repeated @flaky attempts do not accumulate open connections.
    with MongoClient(MONGODB_URI) as client:
        client[DB_NAME][clxn_name].drop()

    # Test additional examples
    store_with_addl_examples = MongoDBGraphStore(
        connection_string=MONGODB_URI,
        database_name=DB_NAME,
        collection_name=clxn_name,
        entity_extraction_model=entity_extraction_model,
        entity_prompt=entity_prompt,
        query_prompt=query_prompt,
        entity_examples=entity_example,
    )
    # NOTE(review): this second drop looks redundant with the one above;
    # kept so the sequence of database operations is unchanged.
    store_with_addl_examples.collection.drop()

    store_with_addl_examples.add_documents(documents)
    entity_names = ["ACME Corporation", "GreenTech Ltd."]
    new_entities = store_with_addl_examples.related_entities(entity_names)
    assert len(new_entities) >= 2
172186
173187
@flaky
def test_chat_response(graph_store, query_connection):
    """Query an already-populated knowledge graph and check the chat reply.

    The reply must be an ``AIMessage`` whose text mentions
    "acme corporation" (compared case-insensitively).
    """
    response = graph_store.chat_response(query_connection)
    assert isinstance(response, AIMessage)
    reply_text = response.content.lower()
    assert "acme corporation" in reply_text
179194
180195
196+ @flaky
181197def test_similarity_search (graph_store , query_connection ):
182198 docs = graph_store .similarity_search (query_connection )
183199 assert len (docs ) >= 4
@@ -186,33 +202,78 @@ def test_similarity_search(graph_store, query_connection):
186202 assert any ("attributes" in d .keys () for d in docs )
187203
188204
@flaky
def test_validator(documents, entity_extraction_model):
    """Exercise ``validate=True`` against three collection states.

    1. The collection does not exist when the store is constructed.
    2. The collection already exists (left over from case 1) and is emptied.
    3. The collection is freshly created with ``create_collection``.

    In every case all ``documents`` must be written with
    ``validation_action="error"``.

    Each ``MongoClient`` is scoped with a context manager so it is closed
    even when an assertion fails; otherwise every failed @flaky attempt
    would leave a client open.
    """
    clxn_name = f"{COLLECTION_NAME}_validation"

    # Case 1. No existing collection.
    # The client is needed only for the drop; the store below opens its own
    # connection from the connection string.
    with MongoClient(MONGODB_URI) as client:
        client[DB_NAME][clxn_name].drop()
    # now we call with validation that can be added without db admin privileges
    store = MongoDBGraphStore(
        connection_string=MONGODB_URI,
        database_name=DB_NAME,
        collection_name=clxn_name,
        entity_extraction_model=entity_extraction_model,
        validate=True,
        validation_action="error",
    )
    bulkwrite_results = store.add_documents(documents)
    assert len(bulkwrite_results) == len(documents)
    entities = store.collection.find({}).to_list()
    # Using subset because SolarGrid Initiative is not always considered an entity
    assert {"Person", "Organization"}.issubset({e["type"] for e in entities})

    # Case 2: Existing collection with a validator
    with MongoClient(MONGODB_URI) as client:
        collection = client[DB_NAME][clxn_name]
        # Empty the collection but keep it (and whatever options case 1
        # left on it) in place.
        collection.delete_many({})
        store = MongoDBGraphStore(
            collection=collection,
            entity_extraction_model=entity_extraction_model,
            validate=True,
            validation_action="error",
        )
        bulkwrite_results = store.add_documents(documents)
        assert len(bulkwrite_results) == len(documents)
        # Remove the collection so case 3 can create it from scratch.
        collection.drop()

    # Case 3: Existing collection without a validator
    with MongoClient(MONGODB_URI) as client:
        collection = client[DB_NAME].create_collection(clxn_name)
        store = MongoDBGraphStore(
            collection=collection,
            entity_extraction_model=entity_extraction_model,
            validate=True,
            validation_action="error",
        )
        bulkwrite_results = store.add_documents(documents)
        assert len(bulkwrite_results) == len(documents)
257+
258+
259+ @flaky
204260def test_allowed_entity_types (documents , entity_extraction_model ):
205261 """Add allowed_entity_types. Use the validator to confirm behaviour."""
206262 allowed_entity_types = ["Person" ]
263+ # drop collection
207264 client = MongoClient (MONGODB_URI )
208265 collection_name = f"{ COLLECTION_NAME } _allowed_entity_types"
209266 client [DB_NAME ][collection_name ].drop ()
210- collection = client [DB_NAME ].create_collection (collection_name )
267+ # create knowledge graph with only allowed_entity_types
268+ # this changes the schema at runtime
211269 store = MongoDBGraphStore (
212- collection ,
213- entity_extraction_model ,
214270 allowed_entity_types = allowed_entity_types ,
215271 validate = True ,
272+ validation_action = "error" ,
273+ connection_string = MONGODB_URI ,
274+ database_name = DB_NAME ,
275+ collection_name = collection_name ,
276+ entity_extraction_model = entity_extraction_model ,
216277 )
217278 bulkwrite_results = store .add_documents (documents )
218279 assert len (bulkwrite_results ) == len (documents )
@@ -223,16 +284,24 @@ def test_allowed_entity_types(documents, entity_extraction_model):
223284 all ([len (e ["relationships" ].get ("attributes" , [])) == 0 for e in entities ])
224285
225286
287+ @flaky
226288def test_allowed_relationship_types (documents , entity_extraction_model ):
289+ # drop collection
227290 client = MongoClient (MONGODB_URI )
228- collection_name = f"{ COLLECTION_NAME } _allowed_relationship_types"
229- client [DB_NAME ][collection_name ].drop ()
230- collection = client [DB_NAME ].create_collection (collection_name )
291+ clxn_name = f"{ COLLECTION_NAME } _allowed_relationship_types"
292+ client [DB_NAME ][clxn_name ].drop ()
293+ collection = client [DB_NAME ].create_collection (clxn_name )
294+ collection .drop ()
295+ # create knowledge graph with only allowed_relationship_types=["partner"]
296+ # this changes the schema at runtime
231297 store = MongoDBGraphStore (
232- collection ,
233- entity_extraction_model ,
234298 allowed_relationship_types = ["partner" ],
235299 validate = True ,
300+ validation_action = "error" ,
301+ connection_string = MONGODB_URI ,
302+ database_name = DB_NAME ,
303+ collection_name = clxn_name ,
304+ entity_extraction_model = entity_extraction_model ,
236305 )
237306 bulkwrite_results = store .add_documents (documents )
238307 assert len (bulkwrite_results ) == len (documents )
0 commit comments