Skip to content

Commit 17c39bd

Browse files
INTPYTHON-661 Tests Async use of MongoDBStore (#176)
1 parent a2227cd commit 17c39bd

File tree

3 files changed

+478
-57
lines changed

3 files changed

+478
-57
lines changed

libs/langgraph-store-mongodb/tests/integration_tests/test_store_semantic.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@
2727
DIMENSIONS = 5 # Dimensions of embedding model
2828

2929

30-
def wait(cond: Callable, timeout: int = 15, interval: int = 1) -> None:
30+
def wait_until(
31+
predicate: Callable, timeout: int = TIMEOUT, interval: int = INTERVAL
32+
) -> None:
3133
start = monotonic()
3234
while monotonic() - start < timeout:
33-
if cond():
35+
if predicate():
3436
return
3537
else:
3638
sleep(interval)
37-
raise TimeoutError("timeout waiting for: ", cond)
39+
raise TimeoutError("timeout waiting for predicate: ", predicate)
3840

3941

4042
class StaticEmbeddings(Embeddings):
@@ -59,12 +61,12 @@ def collection() -> Generator[Collection, None, None]:
5961
db = client[DB_NAME]
6062
db.drop_collection(COLLECTION_NAME)
6163
collection = db.create_collection(COLLECTION_NAME)
62-
wait(lambda: collection.count_documents({}) == 0, TIMEOUT, INTERVAL)
64+
wait_until(lambda: collection.count_documents({}) == 0, TIMEOUT, INTERVAL)
6365
try:
6466
collection.drop_search_index(INDEX_NAME)
6567
except OperationFailure:
6668
pass
67-
wait(
69+
wait_until(
6870
lambda: len(collection.list_search_indexes().to_list()) == 0, TIMEOUT, INTERVAL
6971
)
7072

@@ -116,7 +118,7 @@ def test_filters(collection: Collection) -> None:
116118
query = "What is the grade of our pears?"
117119
# Case 1: fields is a string:
118120
namespace_prefix = ("a",) # filter ("a",) catches all docs
119-
wait(
121+
wait_until(
120122
lambda: len(store_mdb.search(namespace_prefix, query=query)) == len(products),
121123
TIMEOUT,
122124
INTERVAL,
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import os
2+
from collections.abc import Generator
3+
from time import monotonic, sleep
4+
from typing import Callable
5+
6+
import pytest
7+
from langchain_core.embeddings import Embeddings
8+
from pymongo import MongoClient
9+
from pymongo.collection import Collection
10+
from pymongo.errors import OperationFailure
11+
12+
from langgraph.store.base import PutOp
13+
from langgraph.store.memory import InMemoryStore
14+
from langgraph.store.mongodb import (
15+
MongoDBStore,
16+
create_vector_index_config,
17+
)
18+
19+
MONGODB_URI = os.environ.get(
20+
"MONGODB_URI", "mongodb://localhost:27017?directConnection=true"
21+
)
22+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
23+
COLLECTION_NAME = "semantic_search_async"
24+
INDEX_NAME = "vector_index"
25+
TIMEOUT, INTERVAL = 30, 1 # timeout to index new data
26+
27+
DIMENSIONS = 5 # Dimensions of embedding model
28+
29+
30+
def wait_until(
    predicate: Callable, timeout: int = TIMEOUT, interval: int = INTERVAL
) -> None:
    """Poll ``predicate`` until it returns a truthy value or time runs out.

    Args:
        predicate: Zero-argument callable evaluated on every poll.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between two consecutive polls.

    Raises:
        TimeoutError: If ``predicate`` never returned a truthy value while
            the elapsed time was still below ``timeout``.
    """
    start = monotonic()
    while monotonic() - start < timeout:
        if predicate():
            return
        sleep(interval)
    # TimeoutError is an OSError subclass: passing two positional arguments
    # would be interpreted as (errno, strerror) and render a misleading
    # "[Errno ...]" message, so build a single formatted message instead.
    raise TimeoutError(f"timeout waiting for predicate: {predicate!r}")
40+
41+
42+
class StaticEmbeddings(Embeddings):
    """Deterministic embedding stub for the store tests.

    ANN Search is not tested here. That is done in langchain-mongodb.
    Texts mentioning "pears" get one fixed vector and every other text gets
    another, so the result ordering in the tests is predictable.
    """

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each text with :meth:`embed_query`, preserving input order."""
        return [self.embed_query(txt) for txt in texts]

    def embed_query(self, text: str) -> list[float]:
        """Return a fixed ``DIMENSIONS``-long vector chosen by keyword match."""
        if "pears" not in text:
            return [0.5] * DIMENSIONS
        # Only the first component differs for texts that mention "pears".
        return [1.0] + [0.5] * (DIMENSIONS - 1)
56+
57+
58+
@pytest.fixture
def collection() -> Generator[Collection, None, None]:
    """Yield a freshly created, empty collection without search indexes.

    The collection named ``COLLECTION_NAME`` is dropped and recreated, then
    the fixture waits until it reports zero documents and zero search
    indexes before handing it to the test.

    Yields:
        The recreated collection.

    Raises:
        TimeoutError: If the collection or its search indexes are not
            observed as empty within ``TIMEOUT`` seconds.
    """
    client: MongoClient = MongoClient(MONGODB_URI)
    # try/finally guarantees the client is closed even when setup fails
    # before the yield (e.g. wait_until times out), not only after the test.
    try:
        db = client[DB_NAME]
        db.drop_collection(COLLECTION_NAME)
        collection = db.create_collection(COLLECTION_NAME)
        wait_until(lambda: collection.count_documents({}) == 0, TIMEOUT, INTERVAL)
        try:
            collection.drop_search_index(INDEX_NAME)
        except OperationFailure:
            # Best-effort cleanup; presumably raised when the index does not
            # exist yet -- the wait below verifies the end state either way.
            pass
        wait_until(
            lambda: len(collection.list_search_indexes().to_list()) == 0,
            TIMEOUT,
            INTERVAL,
        )

        yield collection
    finally:
        client.close()
76+
77+
78+
async def test_filters(collection: Collection) -> None:
    """Test permutations of namespace_prefix in filter via the async API.

    Three documents are written under nested namespaces to both a
    ``MongoDBStore`` and an ``InMemoryStore``; searches with different
    namespace prefixes and a value filter are then compared. Sync and async
    store methods are deliberately mixed across the cases.

    Raises:
        TimeoutError: If the MongoDB search does not return all documents
            within ``TIMEOUT`` seconds of the write.
    """
    # Function-scope import: only the polling loop below needs asyncio.
    import asyncio

    index_config = create_vector_index_config(
        name=INDEX_NAME,
        dims=DIMENSIONS,
        fields=["product"],
        embed=StaticEmbeddings(),  # embedding
        filters=["metadata.available"],
    )
    store_mdb = MongoDBStore(
        collection, index_config=index_config, auto_index_timeout=TIMEOUT
    )
    store_in_mem = InMemoryStore(index=index_config)

    namespaces = [
        ("a",),
        ("a", "b", "c"),
        ("a", "b", "c", "d"),
    ]

    products = ["apples", "oranges", "pears"]

    # Add some indexed data: one product per namespace. Only the document
    # with an odd index (i == 1, "oranges") is marked as available.
    put_ops = []
    for i, ns in enumerate(namespaces):
        put_ops.append(
            PutOp(
                namespace=ns,
                key=f"id_{i}",
                value={
                    "product": products[i],
                    "metadata": {"available": bool(i % 2), "grade": "A" * (i + 1)},
                },
            )
        )

    await store_mdb.abatch(put_ops)
    store_in_mem.batch(put_ops)

    query = "What is the grade of our pears?"
    # Case 1: prefix is the top-level namespace
    namespace_prefix = ("a",)  # filter ("a",) catches all docs

    # In our first search, we'll retry until the search returns all new docs.
    start = monotonic()
    indexed = False
    while monotonic() - start < TIMEOUT:
        if len(await store_mdb.asearch(namespace_prefix, query=query)) == len(products):
            indexed = True
            break
        # Yield to the event loop while waiting instead of blocking it with
        # time.sleep, which would stall every other task on the loop.
        await asyncio.sleep(INTERVAL)
    if not indexed:
        raise TimeoutError("timeout waiting for: vector_index")

    result_mdb = await store_mdb.asearch(namespace_prefix, query=query)
    assert result_mdb[0].value["product"] == "pears"  # test sorted by score

    result_mem = store_in_mem.search(namespace_prefix, query=query)
    assert len(result_mem) == len(products)

    # Case 2: filter on 2nd namespace in hierarchy
    namespace_prefix = ("a", "b")
    result_mem = await store_in_mem.asearch(namespace_prefix, query=query)
    result_mdb = store_mdb.search(namespace_prefix, query=query)
    # prefix ("a", "b") excludes the single doc stored directly under ("a",)
    assert len(result_mem) == len(result_mdb) == len(products) - 1
    assert result_mdb[0].value["product"] == "pears"

    # Case 3: Empty namespace_prefix
    namespace_prefix = ("",)
    result_mem = store_in_mem.search(namespace_prefix, query=query)
    result_mdb = await store_mdb.asearch(namespace_prefix, query=query)
    assert len(result_mem) == len(result_mdb) == 0

    # Case 4: With filter on value (nested)
    namespace_prefix = ("a",)
    available = {"metadata.available": True}
    result_mdb = await store_mdb.asearch(
        namespace_prefix, query=query, filter=available
    )
    assert result_mdb[0].value["product"] == "oranges"
    assert len(result_mdb) == 1

0 commit comments

Comments
 (0)