|
20 | 20 | from langgraph.checkpoint.serde.base import SerializerProtocol |
21 | 21 | from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer |
22 | 22 | from pymongo import ASCENDING, MongoClient, UpdateOne |
| 23 | +from pymongo.collection import Collection |
23 | 24 | from pymongo.database import Database as MongoDatabase |
24 | 25 |
|
25 | 26 | from .utils import DRIVER_METADATA, dumps_metadata, loads_metadata |
26 | 27 |
|
27 | 28 |
|
| 29 | +def _create_saver_indexes( |
| 30 | + collection: Collection, |
| 31 | + compound_index: list[tuple[str, int]], |
| 32 | + ttl: Optional[int] = None, |
| 33 | +) -> None: |
| 34 | + """Create indexes for the saver collections. |
| 35 | +
|
| 36 | + This helper function creates the given compound index and TTL index (if required) |
| 37 | + for the given collection. |
| 38 | +
|
| 39 | + Args: |
| 40 | + collection (Collection): The MongoDB collection to create indexes on. |
| 41 | + compound_index (list[tuple[str, int]]): The compound index to create. |
| 42 | + ttl (int, optional): Time to live in seconds for the TTL index. Defaults to None. |
| 43 | + """ |
| 44 | + |
| 45 | + def index_key_list(index: Any) -> list[tuple[str, int]]: |
| 46 | + return list((k, v) for k, v in index["key"].items()) |
| 47 | + |
| 48 | + indexes = list(collection.list_indexes()) |
| 49 | + index_keys = [index_key_list(idx) for idx in indexes] |
| 50 | + if compound_index not in index_keys: |
| 51 | + collection.create_index(compound_index, unique=True) |
| 52 | + if ttl is not None: |
| 53 | + ttl_index = [("created_at", ASCENDING)] |
| 54 | + found = False |
| 55 | + for idx in indexes: |
| 56 | + if ( |
| 57 | + index_key_list(idx) == tuple(ttl_index) |
| 58 | + and idx.get("expireAfterSeconds") == ttl |
| 59 | + ): |
| 60 | + found = True |
| 61 | + break |
| 62 | + if not found: |
| 63 | + collection.create_index(ttl_index, expireAfterSeconds=ttl) |
| 64 | + |
| 65 | + |
28 | 66 | class MongoDBSaver(BaseCheckpointSaver): |
29 | 67 | """A checkpointer that stores StateGraph checkpoints in a MongoDB database. |
30 | 68 |
|
@@ -97,34 +135,22 @@ def __init__( |
97 | 135 | else: |
98 | 136 | self.serde = JsonPlusSerializer() |
99 | 137 |
|
100 | | - # Create indexes if not present |
101 | | - if len(self.checkpoint_collection.list_indexes().to_list()) < 2: |
102 | | - self.checkpoint_collection.create_index( |
103 | | - keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)], |
104 | | - unique=True, |
105 | | - ) |
106 | | - if self.ttl: |
107 | | - self.checkpoint_collection.create_index( |
108 | | - keys=[("created_at", ASCENDING)], |
109 | | - expireAfterSeconds=self.ttl, |
110 | | - ) |
111 | | - |
112 | | - if len(self.writes_collection.list_indexes().to_list()) < 2: |
113 | | - self.writes_collection.create_index( |
114 | | - keys=[ |
115 | | - ("thread_id", 1), |
116 | | - ("checkpoint_ns", 1), |
117 | | - ("checkpoint_id", -1), |
118 | | - ("task_id", 1), |
119 | | - ("idx", 1), |
120 | | - ], |
121 | | - unique=True, |
122 | | - ) |
123 | | - if self.ttl: |
124 | | - self.writes_collection.create_index( |
125 | | - keys=[("created_at", ASCENDING)], |
126 | | - expireAfterSeconds=self.ttl, |
127 | | - ) |
| 138 | + _create_saver_indexes( |
| 139 | + self.checkpoint_collection, |
| 140 | + [("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)], |
| 141 | + self.ttl, |
| 142 | + ) |
| 143 | + _create_saver_indexes( |
| 144 | + self.writes_collection, |
| 145 | + [ |
| 146 | + ("thread_id", 1), |
| 147 | + ("checkpoint_ns", 1), |
| 148 | + ("checkpoint_id", -1), |
| 149 | + ("task_id", 1), |
| 150 | + ("idx", 1), |
| 151 | + ], |
| 152 | + self.ttl, |
| 153 | + ) |
128 | 154 |
|
129 | 155 | @classmethod |
130 | 156 | @contextmanager |
|
0 commit comments