Skip to content

Commit 8a8e580

Browse files
authored
INTPYTHON-886 - Fix LangGraph checkpointer index creation logic (langchain-ai#313)
INTPYTHON-886
1 parent 54c2150 commit 8a8e580

File tree

2 files changed

+98
-28
lines changed

2 files changed

+98
-28
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,49 @@
2020
from langgraph.checkpoint.serde.base import SerializerProtocol
2121
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
2222
from pymongo import ASCENDING, MongoClient, UpdateOne
23+
from pymongo.collection import Collection
2324
from pymongo.database import Database as MongoDatabase
2425

2526
from .utils import DRIVER_METADATA, dumps_metadata, loads_metadata
2627

2728

29+
def _create_saver_indexes(
30+
collection: Collection,
31+
compound_index: list[tuple[str, int]],
32+
ttl: Optional[int] = None,
33+
) -> None:
34+
"""Create indexes for the saver collections.
35+
36+
This helper function creates the given compound index and TTL index (if required)
37+
for the given collection.
38+
39+
Args:
40+
collection (Collection): The MongoDB collection to create indexes on.
41+
compound_index (list[tuple[str, int]]): The compound index to create.
42+
ttl (int, optional): Time to live in seconds for the TTL index. Defaults to None.
43+
"""
44+
45+
def index_key_list(index: Any) -> list[tuple[str, int]]:
46+
return list((k, v) for k, v in index["key"].items())
47+
48+
indexes = list(collection.list_indexes())
49+
index_keys = [index_key_list(idx) for idx in indexes]
50+
if compound_index not in index_keys:
51+
collection.create_index(compound_index, unique=True)
52+
if ttl is not None:
53+
ttl_index = [("created_at", ASCENDING)]
54+
found = False
55+
for idx in indexes:
56+
if (
57+
index_key_list(idx) == tuple(ttl_index)
58+
and idx.get("expireAfterSeconds") == ttl
59+
):
60+
found = True
61+
break
62+
if not found:
63+
collection.create_index(ttl_index, expireAfterSeconds=ttl)
64+
65+
2866
class MongoDBSaver(BaseCheckpointSaver):
2967
"""A checkpointer that stores StateGraph checkpoints in a MongoDB database.
3068
@@ -97,34 +135,22 @@ def __init__(
97135
else:
98136
self.serde = JsonPlusSerializer()
99137

100-
# Create indexes if not present
101-
if len(self.checkpoint_collection.list_indexes().to_list()) < 2:
102-
self.checkpoint_collection.create_index(
103-
keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
104-
unique=True,
105-
)
106-
if self.ttl:
107-
self.checkpoint_collection.create_index(
108-
keys=[("created_at", ASCENDING)],
109-
expireAfterSeconds=self.ttl,
110-
)
111-
112-
if len(self.writes_collection.list_indexes().to_list()) < 2:
113-
self.writes_collection.create_index(
114-
keys=[
115-
("thread_id", 1),
116-
("checkpoint_ns", 1),
117-
("checkpoint_id", -1),
118-
("task_id", 1),
119-
("idx", 1),
120-
],
121-
unique=True,
122-
)
123-
if self.ttl:
124-
self.writes_collection.create_index(
125-
keys=[("created_at", ASCENDING)],
126-
expireAfterSeconds=self.ttl,
127-
)
138+
_create_saver_indexes(
139+
self.checkpoint_collection,
140+
[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
141+
self.ttl,
142+
)
143+
_create_saver_indexes(
144+
self.writes_collection,
145+
[
146+
("thread_id", 1),
147+
("checkpoint_ns", 1),
148+
("checkpoint_id", -1),
149+
("task_id", 1),
150+
("idx", 1),
151+
],
152+
self.ttl,
153+
)
128154

129155
@classmethod
130156
@contextmanager

libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_sync.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,47 @@ def test_ttl(input_data: dict[str, Any]) -> None:
209209
saver.checkpoint_collection.drop_indexes()
210210
saver.writes_collection.delete_many({})
211211
saver.writes_collection.drop_indexes()
212+
213+
214+
def test_init_creates_indexes() -> None:
    """Check that constructing a saver creates the compound and TTL indexes.

    Both test collections are dropped first so the saver starts without any
    indexes. The index information is then read back and checked for the
    unique compound index and for a ``created_at`` TTL index carrying the
    requested expiry. Cleanup runs in ``finally`` so a failed assertion does
    not leave the collections or the client connection behind.
    """
    client: MongoClient = MongoClient(MONGODB_URI)
    db = client[DB_NAME]
    checkpoint_coll = "checkpoints_test"
    writes_coll = "writes_test"
    ttl = 100

    def _has_index(
        index_info: Any, keys: list[tuple[str, int]], **options: Any
    ) -> bool:
        # True when some index has exactly these keys and every given option.
        return any(
            info.get("key") == keys
            and all(info.get(name) == value for name, value in options.items())
            for info in index_info.values()
        )

    try:
        db.drop_collection(checkpoint_coll)
        db.drop_collection(writes_coll)

        with MongoDBSaver.from_conn_string(
            MONGODB_URI, DB_NAME, checkpoint_coll, writes_coll, ttl=ttl
        ) as saver:
            cp_indexes = saver.checkpoint_collection.index_information()
            wr_indexes = saver.writes_collection.index_information()

        expected_cp_keys = [
            ("thread_id", 1),
            ("checkpoint_ns", 1),
            ("checkpoint_id", -1),
        ]
        assert _has_index(cp_indexes, expected_cp_keys, unique=True)
        assert _has_index(cp_indexes, [("created_at", 1)], expireAfterSeconds=ttl)

        expected_wr_keys = [
            ("thread_id", 1),
            ("checkpoint_ns", 1),
            ("checkpoint_id", -1),
            ("task_id", 1),
            ("idx", 1),
        ]
        assert _has_index(wr_indexes, expected_wr_keys, unique=True)
        assert _has_index(wr_indexes, [("created_at", 1)], expireAfterSeconds=ttl)
    finally:
        db.drop_collection(checkpoint_coll)
        db.drop_collection(writes_coll)
        client.close()

0 commit comments

Comments
 (0)