Skip to content

Commit 964277e

Browse files
Updates MongoDBGraphStore.__init__ signature (#83)
Also hardens validation.
1 parent ccee700 commit 964277e

File tree

5 files changed

+165
-35
lines changed

5 files changed

+165
-35
lines changed

libs/langchain-mongodb/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
---
44

5-
## Changes in version 0.5 (2025/02/11)
5+
## Changes in version 0.5 (2025/02/25)
66

77
- Added GraphRAG support via `MongoDBGraphStore`
88

libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,15 @@ class MongoDBGraphStore:
8787
from completely different sources.
8888
- "Jane Smith works with John Doe."
8989
- "Jane Smith works at MongoDB."
90-
9190
"""
9291

9392
def __init__(
9493
self,
95-
collection: Collection,
94+
*,
95+
connection_string: Optional[str] = None,
96+
database_name: Optional[str] = None,
97+
collection_name: Optional[str] = None,
98+
collection: Optional[Collection] = None,
9699
entity_extraction_model: BaseChatModel,
97100
entity_prompt: ChatPromptTemplate = None,
98101
query_prompt: ChatPromptTemplate = None,
@@ -106,7 +109,11 @@ def __init__(
106109
):
107110
"""
108111
Args:
109-
collection: Collection representing an Entity Graph.
112+
connection_string: A valid MongoDB connection URI.
113+
database_name: The name of the database to connect to.
114+
collection_name: The name of the collection to connect to.
115+
collection: A Collection that will represent a Knowledge Graph.
116+
** One may pass a Collection in lieu of connection_string, database_name, and collection_name.
110117
entity_extraction_model: LLM for converting documents into Graph of Entities and Relationships.
111118
entity_prompt: Prompt to fill graph store with entities following schema.
112119
Defaults to .prompts.ENTITY_EXTRACTION_INSTRUCTIONS
@@ -122,6 +129,62 @@ def __init__(
122129
- If "warn", the default, documents will be inserted but errors logged.
123130
- If "error", an exception will be raised if any document does not match the schema.
124131
"""
132+
self._schema = deepcopy(entity_schema)
133+
collection_existed = True
134+
if connection_string and collection is not None:
135+
raise ValueError(
136+
"Pass one of: connection_string, database_name, and collection_name"
137+
"OR a MongoDB Collection."
138+
)
139+
if collection is None: # collection is specified by uri and names
140+
client: MongoClient = MongoClient(
141+
connection_string,
142+
driver=DriverInfo(
143+
name="Langchain", version=version("langchain-mongodb")
144+
),
145+
)
146+
db = client[database_name]
147+
if collection_name not in db.list_collection_names():
148+
validator = {"$jsonSchema": self._schema} if validate else None
149+
collection = client[database_name].create_collection(
150+
collection_name,
151+
validator=validator,
152+
validationAction=validation_action,
153+
)
154+
collection_existed = False
155+
else:
156+
collection = db[collection_name]
157+
else:
158+
if not isinstance(collection, Collection):
159+
raise ValueError(
160+
"collection must be a MongoDB Collection. "
161+
"Consider using connection_string, database_name, and collection_name."
162+
)
163+
164+
if validate and collection_existed:
165+
# first check for existing validator
166+
collection_info = collection.database.command(
167+
"listCollections", filter={"name": collection.name}
168+
)
169+
collection_options = collection_info.get("cursor", {}).get("firstBatch", [])
170+
validator = collection_options[0].get("options", {}).get("validator", None)
171+
if not validator:
172+
try:
173+
collection.database.command(
174+
"collMod",
175+
collection.name,
176+
validator={"$jsonSchema": self._schema},
177+
validationAction=validation_action,
178+
)
179+
except OperationFailure:
180+
logger.warning(
181+
"Validation will NOT be performed. "
182+
"User must be DB Admin to add validation **after** a Collection is created. \n"
183+
"Please add validator when you create collection: "
184+
"db.create_collection.(coll_name, validator={'$jsonSchema': schema.entity_schema})"
185+
)
186+
self.collection = collection
187+
125188
self.entity_extraction_model = entity_extraction_model
126189
self.entity_prompt = (
127190
prompts.entity_prompt if entity_prompt is None else entity_prompt
@@ -145,20 +208,6 @@ def __init__(
145208
] = allowed_relationship_types
146209
else:
147210
self.allowed_relationship_types = []
148-
if validate:
149-
try:
150-
collection.database.command(
151-
"collMod",
152-
collection.name,
153-
validator={"$jsonSchema": self._schema},
154-
validationAction=validation_action,
155-
)
156-
except OperationFailure:
157-
logger.warning(
158-
"Validation will NOT be performed. User must be DB Admin to add validation **after** a Collection is created. \n"
159-
"Please add validator when you create collection: db.create_collection.(coll_name, validator={'$jsonSchema': self._schema})"
160-
)
161-
self.collection = collection
162211

163212
# Include examples
164213
if entity_examples is None:

libs/langchain-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dev = [
3636
"langchain-openai>=0.2.14",
3737
"langchain-community>=0.3.14",
3838
"pypdf>=5.0.1",
39+
"flaky>=3.8.1",
3940
]
4041

4142
[tool.pytest.ini_options]

libs/langchain-mongodb/tests/integration_tests/test_graphrag.py

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
import pytest
4+
from flaky import flaky
45
from langchain_core.documents import Document
56
from langchain_core.language_models.chat_models import BaseChatModel
67
from langchain_core.messages import AIMessage
@@ -117,7 +118,10 @@ def entity_example():
117118
@pytest.fixture(scope="module")
118119
def graph_store(collection, entity_extraction_model, documents) -> MongoDBGraphStore:
119120
store = MongoDBGraphStore(
120-
collection, entity_extraction_model, entity_prompt, query_prompt
121+
collection=collection,
122+
entity_extraction_model=entity_extraction_model,
123+
entity_prompt=entity_prompt,
124+
query_prompt=query_prompt,
121125
)
122126
bulkwrite_results = store.add_documents(documents)
123127
assert len(bulkwrite_results) == len(documents)
@@ -136,6 +140,7 @@ def test_add_docs_store(graph_store):
136140
assert 4 <= len(extracted_entities) < 8
137141

138142

143+
@flaky
139144
def test_extract_entity_names(graph_store, query_connection):
140145
query_entity_names = graph_store.extract_entity_names(query_connection)
141146
assert set(query_entity_names) == {"John Doe", "Jane Smith"}
@@ -145,6 +150,7 @@ def test_extract_entity_names(graph_store, query_connection):
145150
assert len(no_names) == 0
146151

147152

153+
@flaky
148154
def test_related_entities(graph_store):
149155
entity_names = ["John Doe", "Jane Smith"]
150156
related_entities = graph_store.related_entities(entity_names)
@@ -155,29 +161,39 @@ def test_related_entities(graph_store):
155161
assert len(no_entities) == 0
156162

157163

164+
@flaky
158165
def test_additional_entity_examples(entity_extraction_model, entity_example, documents):
159-
# Test additional examples
166+
# First, create one client just to drop any existing collections
160167
client = MongoClient(MONGODB_URI)
161-
db = client[DB_NAME]
162168
clxn_name = f"{COLLECTION_NAME}_addl_examples"
163-
db[clxn_name].drop()
164-
collection = db.create_collection(clxn_name)
169+
client[DB_NAME][clxn_name].drop()
170+
# Test additional examples
165171
store_with_addl_examples = MongoDBGraphStore(
166-
collection, entity_extraction_model, entity_examples=entity_example
172+
connection_string=MONGODB_URI,
173+
database_name=DB_NAME,
174+
collection_name=clxn_name,
175+
entity_extraction_model=entity_extraction_model,
176+
entity_prompt=entity_prompt,
177+
query_prompt=query_prompt,
178+
entity_examples=entity_example,
167179
)
180+
store_with_addl_examples.collection.drop()
181+
168182
store_with_addl_examples.add_documents(documents)
169183
entity_names = ["ACME Corporation", "GreenTech Ltd."]
170184
new_entities = store_with_addl_examples.related_entities(entity_names)
171185
assert len(new_entities) >= 2
172186

173187

188+
@flaky
174189
def test_chat_response(graph_store, query_connection):
175190
"""Displays querying an existing Knowledge Graph Database"""
176191
answer = graph_store.chat_response(query_connection)
177192
assert isinstance(answer, AIMessage)
178193
assert "acme corporation" in answer.content.lower()
179194

180195

196+
@flaky
181197
def test_similarity_search(graph_store, query_connection):
182198
docs = graph_store.similarity_search(query_connection)
183199
assert len(docs) >= 4
@@ -186,33 +202,78 @@ def test_similarity_search(graph_store, query_connection):
186202
assert any("attributes" in d.keys() for d in docs)
187203

188204

205+
@flaky
189206
def test_validator(documents, entity_extraction_model):
207+
# Case 1. No existing collection.
190208
client = MongoClient(MONGODB_URI)
191-
clxn_name = "langchain_test_graphrag_validation"
209+
clxn_name = f"{COLLECTION_NAME}_validation"
192210
client[DB_NAME][clxn_name].drop()
193-
clxn = client[DB_NAME].create_collection(clxn_name)
211+
# now we call with validation that can be added without db admin privileges
194212
store = MongoDBGraphStore(
195-
clxn, entity_extraction_model, validate=True, validation_action="error"
213+
connection_string=MONGODB_URI,
214+
database_name=DB_NAME,
215+
collection_name=clxn_name,
216+
entity_extraction_model=entity_extraction_model,
217+
validate=True,
218+
validation_action="error",
196219
)
197220
bulkwrite_results = store.add_documents(documents)
198221
assert len(bulkwrite_results) == len(documents)
199222
entities = store.collection.find({}).to_list()
200223
# Using subset because SolarGrid Initiative is not always considered an entity
201224
assert {"Person", "Organization"}.issubset(set(e["type"] for e in entities))
225+
client.close()
202226

227+
# Case 2: Existing collection with a validator
228+
client = MongoClient(MONGODB_URI)
229+
clxn_name = f"{COLLECTION_NAME}_validation"
230+
collection = client[DB_NAME][clxn_name]
231+
collection.delete_many({})
203232

233+
store = MongoDBGraphStore(
234+
collection=collection,
235+
entity_extraction_model=entity_extraction_model,
236+
validate=True,
237+
validation_action="error",
238+
)
239+
bulkwrite_results = store.add_documents(documents)
240+
assert len(bulkwrite_results) == len(documents)
241+
collection.drop()
242+
client.close()
243+
244+
# Case 3: Existing collection without a validator
245+
client = MongoClient(MONGODB_URI)
246+
clxn_name = f"{COLLECTION_NAME}_validation"
247+
collection = client[DB_NAME].create_collection(clxn_name)
248+
store = MongoDBGraphStore(
249+
collection=collection,
250+
entity_extraction_model=entity_extraction_model,
251+
validate=True,
252+
validation_action="error",
253+
)
254+
bulkwrite_results = store.add_documents(documents)
255+
assert len(bulkwrite_results) == len(documents)
256+
client.close()
257+
258+
259+
@flaky
204260
def test_allowed_entity_types(documents, entity_extraction_model):
205261
"""Add allowed_entity_types. Use the validator to confirm behaviour."""
206262
allowed_entity_types = ["Person"]
263+
# drop collection
207264
client = MongoClient(MONGODB_URI)
208265
collection_name = f"{COLLECTION_NAME}_allowed_entity_types"
209266
client[DB_NAME][collection_name].drop()
210-
collection = client[DB_NAME].create_collection(collection_name)
267+
# create knowledge graph with only allowed_entity_types
268+
# this changes the schema at runtime
211269
store = MongoDBGraphStore(
212-
collection,
213-
entity_extraction_model,
214270
allowed_entity_types=allowed_entity_types,
215271
validate=True,
272+
validation_action="error",
273+
connection_string=MONGODB_URI,
274+
database_name=DB_NAME,
275+
collection_name=collection_name,
276+
entity_extraction_model=entity_extraction_model,
216277
)
217278
bulkwrite_results = store.add_documents(documents)
218279
assert len(bulkwrite_results) == len(documents)
@@ -223,16 +284,24 @@ def test_allowed_entity_types(documents, entity_extraction_model):
223284
all([len(e["relationships"].get("attributes", [])) == 0 for e in entities])
224285

225286

287+
@flaky
226288
def test_allowed_relationship_types(documents, entity_extraction_model):
289+
# drop collection
227290
client = MongoClient(MONGODB_URI)
228-
collection_name = f"{COLLECTION_NAME}_allowed_relationship_types"
229-
client[DB_NAME][collection_name].drop()
230-
collection = client[DB_NAME].create_collection(collection_name)
291+
clxn_name = f"{COLLECTION_NAME}_allowed_relationship_types"
292+
client[DB_NAME][clxn_name].drop()
293+
collection = client[DB_NAME].create_collection(clxn_name)
294+
collection.drop()
295+
# create knowledge graph with only allowed_relationship_types=["partner"]
296+
# this changes the schema at runtime
231297
store = MongoDBGraphStore(
232-
collection,
233-
entity_extraction_model,
234298
allowed_relationship_types=["partner"],
235299
validate=True,
300+
validation_action="error",
301+
connection_string=MONGODB_URI,
302+
database_name=DB_NAME,
303+
collection_name=clxn_name,
304+
entity_extraction_model=entity_extraction_model,
236305
)
237306
bulkwrite_results = store.add_documents(documents)
238307
assert len(bulkwrite_results) == len(documents)

libs/langchain-mongodb/uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)