|
1 | 1 | import os |
2 | | -from collections.abc import Generator |
3 | | -from typing import Union |
| 2 | +from collections.abc import AsyncGenerator |
| 3 | +from datetime import datetime |
4 | 4 |
|
5 | | -import pytest |
| 5 | +import pytest_asyncio |
6 | 6 | from pymongo import MongoClient |
7 | 7 |
|
8 | 8 | from langgraph.store.base import ( |
9 | 9 | GetOp, |
| 10 | + Item, |
10 | 11 | ListNamespacesOp, |
| 12 | + MatchCondition, |
11 | 13 | PutOp, |
12 | | - SearchOp, |
13 | 14 | TTLConfig, |
14 | 15 | ) |
15 | 16 | from langgraph.store.mongodb import ( |
|
20 | 21 | "MONGODB_URI", "mongodb://localhost:27017?directConnection=true" |
21 | 22 | ) |
22 | 23 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test") |
23 | | -COLLECTION_NAME = "async_store" |
| 24 | +COLLECTION_NAME = "long_term_memory_aio" |
24 | 25 |
|
25 | 26 |
|
26 | | -@pytest.fixture |
27 | | -def store() -> Generator: |
| 27 | +t0 = (datetime(2025, 4, 7, 17, 29, 10, 0),) |
| 28 | + |
| 29 | + |
| 30 | +@pytest_asyncio.fixture |
| 31 | +async def store() -> AsyncGenerator: |
28 | 32 | """Create a simple store following that in base's test_list_namespaces_basic""" |
29 | 33 | client: MongoClient = MongoClient(MONGODB_URI) |
30 | 34 | collection = client[DB_NAME][COLLECTION_NAME] |
31 | 35 | collection.delete_many({}) |
32 | 36 | collection.drop_indexes() |
33 | 37 |
|
34 | | - yield MongoDBStore( |
| 38 | + mdbstore = MongoDBStore( |
35 | 39 | collection, |
36 | 40 | ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=True), |
37 | 41 | ) |
38 | 42 |
|
| 43 | + namespaces = [ |
| 44 | + ("a", "b", "c"), |
| 45 | + ("a", "b", "d", "e"), |
| 46 | + ("a", "b", "d", "i"), |
| 47 | + ("a", "b", "f"), |
| 48 | + ("a", "c", "f"), |
| 49 | + ("b", "a", "f"), |
| 50 | + ("users", "123"), |
| 51 | + ("users", "456", "settings"), |
| 52 | + ("admin", "users", "789"), |
| 53 | + ] |
| 54 | + for i, ns in enumerate(namespaces): |
| 55 | + await mdbstore.aput( |
| 56 | + namespace=ns, key=f"id_{i}", value={"data": f"value_{i:02d}"} |
| 57 | + ) |
| 58 | + |
| 59 | + yield mdbstore |
| 60 | + |
39 | 61 | if client: |
40 | 62 | client.close() |
41 | 63 |
|
42 | 64 |
|
43 | | -async def test_batch_async(store: MongoDBStore) -> None: |
44 | | - N = 100 |
45 | | - M = 5 |
46 | | - ops: list[Union[PutOp, GetOp, ListNamespacesOp, SearchOp]] = [] |
47 | | - for m in range(M): |
48 | | - for i in range(N): |
49 | | - ops.append( |
50 | | - PutOp( |
51 | | - ("test", "foo", "bar", "baz", str(m % 2)), |
52 | | - f"key{i}", |
53 | | - value={"foo": "bar" + str(i)}, |
54 | | - ) |
55 | | - ) |
56 | | - ops.append( |
57 | | - GetOp( |
58 | | - ("test", "foo", "bar", "baz", str(m % 2)), |
59 | | - f"key{i}", |
60 | | - ) |
61 | | - ) |
62 | | - ops.append( |
63 | | - ListNamespacesOp( |
64 | | - match_conditions=None, |
65 | | - max_depth=m + 1, |
66 | | - ) |
67 | | - ) |
68 | | - ops.append( |
69 | | - SearchOp( |
70 | | - ("test",), |
71 | | - ) |
72 | | - ) |
73 | | - ops.append( |
74 | | - PutOp( |
75 | | - ("test", "foo", "bar", "baz", str(m % 2)), |
76 | | - f"key{i}", |
77 | | - value={"foo": "bar" + str(i)}, |
78 | | - ) |
79 | | - ) |
80 | | - ops.append( |
81 | | - PutOp(("test", "foo", "bar", "baz", str(m % 2)), f"key{i}", None) |
82 | | - ) |
83 | | - |
84 | | - results = await store.abatch(ops) |
85 | | - assert len(results) == M * N * 6 |
| 65 | +async def test_alist_namespaces(store: MongoDBStore) -> None: |
| 66 | + result = await store.alist_namespaces(prefix=("a", "b")) |
| 67 | + expected = [ |
| 68 | + ("a", "b", "c"), |
| 69 | + ("a", "b", "d", "e"), |
| 70 | + ("a", "b", "d", "i"), |
| 71 | + ("a", "b", "f"), |
| 72 | + ] |
| 73 | + assert sorted(result) == sorted(expected) |
| 74 | + |
| 75 | + result = await store.alist_namespaces(suffix=("f",)) |
| 76 | + expected = [ |
| 77 | + ("a", "b", "f"), |
| 78 | + ("a", "c", "f"), |
| 79 | + ("b", "a", "f"), |
| 80 | + ] |
| 81 | + assert sorted(result) == sorted(expected) |
| 82 | + |
| 83 | + result = await store.alist_namespaces(prefix=("a",), suffix=("f",)) |
| 84 | + expected = [ |
| 85 | + ("a", "b", "f"), |
| 86 | + ("a", "c", "f"), |
| 87 | + ] |
| 88 | + assert sorted(result) == sorted(expected) |
| 89 | + |
| 90 | + result = await store.alist_namespaces( |
| 91 | + prefix=("a",), |
| 92 | + suffix=( |
| 93 | + "b", |
| 94 | + "f", |
| 95 | + ), |
| 96 | + ) |
| 97 | + expected = [("a", "b", "f")] |
| 98 | + assert sorted(result) == sorted(expected) |
| 99 | + |
| 100 | + # Test max_depth and deduplication |
| 101 | + result = await store.alist_namespaces(prefix=("a", "b"), max_depth=3) |
| 102 | + expected = [ |
| 103 | + ("a", "b", "c"), |
| 104 | + ("a", "b", "d"), |
| 105 | + ("a", "b", "f"), |
| 106 | + ] |
| 107 | + assert sorted(result) == sorted(expected) |
| 108 | + |
| 109 | + result = await store.alist_namespaces(prefix=("a", "*", "f")) |
| 110 | + expected = [ |
| 111 | + ("a", "b", "f"), |
| 112 | + ("a", "c", "f"), |
| 113 | + ] |
| 114 | + assert sorted(result) == sorted(expected) |
| 115 | + |
| 116 | + result = await store.alist_namespaces(prefix=("*", "*", "f")) |
| 117 | + expected = [("a", "c", "f"), ("b", "a", "f"), ("a", "b", "f")] |
| 118 | + assert sorted(result) == sorted(expected) |
| 119 | + |
| 120 | + result = await store.alist_namespaces(suffix=("*", "f")) |
| 121 | + expected = [ |
| 122 | + ("a", "b", "f"), |
| 123 | + ("a", "c", "f"), |
| 124 | + ("b", "a", "f"), |
| 125 | + ] |
| 126 | + assert sorted(result) == sorted(expected) |
| 127 | + |
| 128 | + result = await store.alist_namespaces(prefix=("a", "b"), suffix=("d", "i")) |
| 129 | + expected = [("a", "b", "d", "i")] |
| 130 | + assert sorted(result) == sorted(expected) |
| 131 | + |
| 132 | + result = await store.alist_namespaces(prefix=("a", "b"), suffix=("i",)) |
| 133 | + expected = [("a", "b", "d", "i")] |
| 134 | + assert sorted(result) == sorted(expected) |
| 135 | + |
| 136 | + result = await store.alist_namespaces(prefix=("nonexistent",)) |
| 137 | + assert result == [] |
| 138 | + |
| 139 | + result = await store.alist_namespaces() |
| 140 | + assert len(result) == store.collection.count_documents({}) |
| 141 | + |
| 142 | + |
| 143 | +async def test_aget(store: MongoDBStore) -> None: |
| 144 | + result = store.get(namespace=("a", "b", "d", "i"), key="id_2") |
| 145 | + assert isinstance(result, Item) |
| 146 | + assert result.updated_at > result.created_at |
| 147 | + assert result.value == {"data": f"value_{2:02d}"} |
| 148 | + |
| 149 | + result = await store.aget(namespace=("a", "b", "d", "i"), key="id-2") |
| 150 | + assert result is None |
| 151 | + |
| 152 | + result = await store.aget(namespace=tuple(), key="id_2") |
| 153 | + assert result is None |
| 154 | + |
| 155 | + result = await store.aget(namespace=("a", "b", "d", "i"), key="") |
| 156 | + assert result is None |
| 157 | + |
| 158 | + # Test case: refresh_ttl is False |
| 159 | + result = store.collection.find_one(dict(namespace=["a", "b", "d", "i"], key="id_2")) |
| 160 | + assert result is not None |
| 161 | + expected_updated_at = result["updated_at"] |
| 162 | + |
| 163 | + result = await store.aget( |
| 164 | + namespace=("a", "b", "d", "i"), key="id_2", refresh_ttl=False |
| 165 | + ) |
| 166 | + assert result is not None |
| 167 | + assert result.updated_at == expected_updated_at |
| 168 | + |
| 169 | + |
| 170 | +async def test_ttl() -> None: |
| 171 | + namespace = ("a", "b", "c", "d", "e") |
| 172 | + key = "thread" |
| 173 | + value = {"human": "What is the weather in SF?", "ai": "It's always sunny in SF."} |
| 174 | + |
| 175 | + # refresh_on_read is True |
| 176 | + with MongoDBStore.from_conn_string( |
| 177 | + conn_string=MONGODB_URI, |
| 178 | + db_name=DB_NAME, |
| 179 | + collection_name=COLLECTION_NAME + "-ttl", |
| 180 | + ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=True), |
| 181 | + ) as store: |
| 182 | + store.collection.delete_many({}) |
| 183 | + await store.aput(namespace=namespace, key=key, value=value) |
| 184 | + res = store.collection.find_one({}) |
| 185 | + assert res is not None |
| 186 | + orig_updated_at = res["updated_at"] |
| 187 | + res = await store.aget(namespace=namespace, key=key) |
| 188 | + assert res is not None |
| 189 | + found = store.collection.find_one({}) |
| 190 | + assert found is not None |
| 191 | + new_updated_at = found["updated_at"] |
| 192 | + assert new_updated_at > orig_updated_at |
| 193 | + assert res.updated_at == new_updated_at |
| 194 | + |
| 195 | + # refresh_on_read is False |
| 196 | + with MongoDBStore.from_conn_string( |
| 197 | + conn_string=MONGODB_URI, |
| 198 | + db_name=DB_NAME, |
| 199 | + collection_name=COLLECTION_NAME + "-ttl", |
| 200 | + ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=False), |
| 201 | + ) as store: |
| 202 | + store.collection.delete_many({}) |
| 203 | + await store.aput(namespace=namespace, key=key, value=value) |
| 204 | + found = store.collection.find_one({}) |
| 205 | + assert found is not None |
| 206 | + orig_updated_at = found["updated_at"] |
| 207 | + res = await store.aget(namespace=namespace, key=key) |
| 208 | + assert res is not None |
| 209 | + found = store.collection.find_one({}) |
| 210 | + assert found is not None |
| 211 | + new_updated_at = found["updated_at"] |
| 212 | + assert new_updated_at == orig_updated_at |
| 213 | + assert res.updated_at == new_updated_at |
| 214 | + |
| 215 | + # ttl_config is None |
| 216 | + with MongoDBStore.from_conn_string( |
| 217 | + conn_string=MONGODB_URI, |
| 218 | + db_name=DB_NAME, |
| 219 | + collection_name=COLLECTION_NAME + "-ttl", |
| 220 | + ttl_config=None, |
| 221 | + ) as store: |
| 222 | + store.collection.delete_many({}) |
| 223 | + await store.aput(namespace=namespace, key=key, value=value) |
| 224 | + found = store.collection.find_one({}) |
| 225 | + assert found is not None |
| 226 | + orig_updated_at = found["updated_at"] |
| 227 | + res = await store.aget(namespace=namespace, key=key) |
| 228 | + assert res is not None |
| 229 | + found = store.collection.find_one({}) |
| 230 | + assert found is not None |
| 231 | + new_updated_at = found["updated_at"] |
| 232 | + assert new_updated_at > orig_updated_at |
| 233 | + assert res.updated_at == new_updated_at |
| 234 | + |
| 235 | + # refresh_on_read is True but refresh_ttl=False in get() |
| 236 | + with MongoDBStore.from_conn_string( |
| 237 | + conn_string=MONGODB_URI, |
| 238 | + db_name=DB_NAME, |
| 239 | + collection_name=COLLECTION_NAME + "-ttl", |
| 240 | + ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=True), |
| 241 | + ) as store: |
| 242 | + store.collection.delete_many({}) |
| 243 | + await store.aput(namespace=namespace, key=key, value=value) |
| 244 | + found = store.collection.find_one({}) |
| 245 | + assert found is not None |
| 246 | + orig_updated_at = found["updated_at"] |
| 247 | + res = await store.aget(refresh_ttl=False, namespace=namespace, key=key) |
| 248 | + assert res is not None |
| 249 | + found = store.collection.find_one({}) |
| 250 | + assert found is not None |
| 251 | + new_updated_at = found["updated_at"] |
| 252 | + assert new_updated_at == orig_updated_at |
| 253 | + assert res.updated_at == new_updated_at |
| 254 | + |
| 255 | + |
| 256 | +async def test_aput(store: MongoDBStore) -> None: |
| 257 | + n = store.collection.count_documents({}) |
| 258 | + await store.aput(namespace=("a",), key=f"id_{n}", value={"data": f"value_{n:02d}"}) |
| 259 | + assert store.collection.count_documents({}) == n + 1 |
| 260 | + |
| 261 | + # include index kwarg |
| 262 | + await store.aput(("a",), "idx", {"data": "val"}, index=["data"]) |
| 263 | + assert store.collection.count_documents({}) == n + 2 |
| 264 | + |
| 265 | + |
| 266 | +async def test_adelete(store: MongoDBStore) -> None: |
| 267 | + n_items = store.collection.count_documents({}) |
| 268 | + await store.adelete(namespace=("a", "b", "c"), key="id_0") |
| 269 | + assert store.collection.count_documents({}) == n_items - 1 |
| 270 | + |
| 271 | + |
| 272 | +async def test_abatch() -> None: |
| 273 | + """Simple demonstration of order of batch operations. |
| 274 | +
|
| 275 | + Read operations, regardless of their order in the list of operations, |
| 276 | + act on the state of the database at the beginning of the batch. |
| 277 | + These include GetOp SearchOp, and ListNamespacesOp. |
| 278 | +
|
| 279 | + Write operations are applied only *after* reads! |
| 280 | +
|
| 281 | + Cases: |
| 282 | + PutOp |
| 283 | + GetOp |
| 284 | + ListNameSpaces after PutOp |
| 285 | + PutOp as delete after PutOp |
| 286 | +
|
| 287 | + raises: |
| 288 | + match_condition stuff |
| 289 | +
|
| 290 | + - check state after ops in different order |
| 291 | + """ |
| 292 | + namespace = ("a", "b", "c", "d", "e") |
| 293 | + key = "thread" |
| 294 | + value = {"human": "What is the weather in SF?", "ai": "It's always sunny in SF."} |
| 295 | + |
| 296 | + op_put = PutOp(namespace=namespace, key=key, value=value) |
| 297 | + op_del = PutOp(namespace=namespace, key=key, value=None) |
| 298 | + op_get = GetOp(namespace=namespace, key=key) |
| 299 | + cond_pre = MatchCondition(match_type="prefix", path=("a", "b")) |
| 300 | + cond_suf = MatchCondition(match_type="suffix", path=("d", "e")) |
| 301 | + op_list = ListNamespacesOp(match_conditions=(cond_pre, cond_suf)) |
| 302 | + |
| 303 | + with MongoDBStore.from_conn_string( |
| 304 | + conn_string=MONGODB_URI, |
| 305 | + db_name=DB_NAME, |
| 306 | + collection_name=COLLECTION_NAME, |
| 307 | + ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=True), |
| 308 | + ) as store: |
| 309 | + # 1. Put 1, read it, list namespaces, and delete one item. |
| 310 | + # => not any(results) |
| 311 | + store.collection.delete_many({}) |
| 312 | + n_ops = 4 |
| 313 | + results = await store.abatch([op_put, op_get, op_list, op_del]) |
| 314 | + assert store.collection.count_documents({}) == 0 |
| 315 | + assert len(results) == n_ops |
| 316 | + assert not any(results) |
| 317 | + |
| 318 | + # 2. delete, put, get |
| 319 | + # => not any(results) |
| 320 | + n_ops = 3 |
| 321 | + results = await store.abatch([op_get, op_del, op_put]) |
| 322 | + assert store.collection.count_documents({}) == 1 |
| 323 | + assert len(results) == n_ops |
| 324 | + assert not any(results) |
| 325 | + |
| 326 | + # 3. delete, put, get |
| 327 | + # => get sees item from put in previous batch |
| 328 | + n_ops = 2 |
| 329 | + results = await store.abatch([op_del, op_get, op_list]) |
| 330 | + assert results[0] is None |
| 331 | + assert isinstance(results[1], Item) |
| 332 | + assert isinstance(results[2], list) and isinstance(results[2][0], tuple) |
| 333 | + |
| 334 | + |
| 335 | +async def test_asearch_basic(store: MongoDBStore) -> None: |
| 336 | + result = await store.asearch(("a", "b")) |
| 337 | + assert len(result) == 4 |
| 338 | + assert all(isinstance(res, Item) for res in result) |
| 339 | + |
| 340 | + namespace = ("a", "b", "c") |
| 341 | + await store.aput(namespace=namespace, key="id_foo", value={"data": "value_foo"}) |
| 342 | + result = await store.asearch(namespace, filter={"data": "value_foo"}) |
| 343 | + assert len(result) == 1 |
0 commit comments