Skip to content

Commit 7de1b3c

Browse files
authored
Merge branch 'main' into INTPYTHON-828
2 parents 3553ff1 + 965782f commit 7de1b3c

File tree

4 files changed

+73
-11
lines changed

4 files changed

+73
-11
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ It contains the following packages.
3838

3939
- Checkpointing (BaseCheckpointSaver)
4040
- [MongoDBSaver](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_checkpoint_mongodb/saver/langgraph.checkpoint.mongodb.saver.MongoDBSaver.html#mongodbsaver)
41-
- [AsyncMongoDBSaver](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_checkpoint_mongodb/aio/langgraph.checkpoint.mongodb.aio.AsyncMongoDBSaver.html#asyncmongodbsaver)
4241

4342
- Long-term memory (BaseStore)
4443
- [MongoDBStore](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_store_mongodb/base/langgraph.store.mongodb.base.MongoDBStore.html#langgraph.store.mongodb.base.MongoDBStore)

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
CheckpointTuple,
1818
get_checkpoint_id,
1919
)
20+
from langgraph.checkpoint.serde.base import SerializerProtocol
21+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
2022
from pymongo import ASCENDING, MongoClient, UpdateOne
2123
from pymongo.database import Database as MongoDatabase
2224

@@ -81,6 +83,7 @@ def __init__(
8183
checkpoint_collection_name: str = "checkpoints",
8284
writes_collection_name: str = "checkpoint_writes",
8385
ttl: Optional[int] = None,
86+
serde: SerializerProtocol | None = None,
8487
**kwargs: Any,
8588
) -> None:
8689
super().__init__()
@@ -89,6 +92,10 @@ def __init__(
8992
self.checkpoint_collection = self.db[checkpoint_collection_name]
9093
self.writes_collection = self.db[writes_collection_name]
9194
self.ttl = ttl
95+
if serde is not None:
96+
self.serde = serde
97+
else:
98+
self.serde = JsonPlusSerializer()
9299

93100
# Create indexes if not present
94101
if len(self.checkpoint_collection.list_indexes().to_list()) < 2:
@@ -236,7 +243,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
236243
return CheckpointTuple(
237244
{"configurable": config_values},
238245
checkpoint,
239-
loads_metadata(doc["metadata"]),
246+
loads_metadata(self.serde, doc["metadata"]),
240247
(
241248
{
242249
"configurable": {
@@ -291,7 +298,7 @@ def list(
291298

292299
if filter:
293300
for key, value in filter.items():
294-
query[f"metadata.{key}"] = dumps_metadata(value)
301+
query[f"metadata.{key}"] = dumps_metadata(self.serde, value)
295302

296303
if before is not None:
297304
query["checkpoint_id"] = {"$lt": before["configurable"]["checkpoint_id"]}
@@ -325,7 +332,7 @@ def list(
325332
}
326333
},
327334
checkpoint=self.serde.loads_typed((doc["type"], doc["checkpoint"])),
328-
metadata=loads_metadata(doc["metadata"]),
335+
metadata=loads_metadata(self.serde, doc["metadata"]),
329336
parent_config=(
330337
{
331338
"configurable": {
@@ -381,7 +388,7 @@ def put(
381388
"parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
382389
"type": type_,
383390
"checkpoint": serialized_checkpoint,
384-
"metadata": dumps_metadata(metadata),
391+
"metadata": dumps_metadata(self.serde, metadata),
385392
}
386393
upsert_query = {
387394
"thread_id": thread_id,

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,9 @@
88

99
from langgraph.checkpoint.base import CheckpointMetadata
1010
from langgraph.checkpoint.serde.base import SerializerProtocol
11-
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
1211
from pymongo import AsyncMongoClient
1312
from pymongo.driver_info import DriverInfo
1413

15-
serde: SerializerProtocol = JsonPlusSerializer()
16-
1714
DRIVER_METADATA = DriverInfo(
1815
name="Langgraph", version=version("langgraph-checkpoint-mongodb")
1916
)
@@ -25,7 +22,9 @@ def _append_client_metadata(client: AsyncMongoClient) -> None:
2522
client.append_metadata(DRIVER_METADATA)
2623

2724

28-
def loads_metadata(metadata: dict[str, Any]) -> CheckpointMetadata:
25+
def loads_metadata(
26+
serde: SerializerProtocol, metadata: dict[str, Any]
27+
) -> CheckpointMetadata:
2928
"""Deserialize metadata document
3029
3130
The CheckpointMetadata class itself cannot be stored directly in MongoDB,
@@ -38,13 +37,14 @@ def loads_metadata(metadata: dict[str, Any]) -> CheckpointMetadata:
3837
if isinstance(metadata, dict):
3938
output = dict()
4039
for key, value in metadata.items():
41-
output[key] = loads_metadata(value)
40+
output[key] = loads_metadata(serde, value)
4241
return output
4342
else:
4443
return serde.loads_typed(metadata)
4544

4645

4746
def dumps_metadata(
47+
serde: SerializerProtocol,
4848
metadata: Union[CheckpointMetadata, Any],
4949
) -> Union[bytes, dict[str, Any]]:
5050
"""Serialize all values in metadata dictionary.
@@ -54,7 +54,7 @@ def dumps_metadata(
5454
if isinstance(metadata, dict):
5555
output = dict()
5656
for key, value in metadata.items():
57-
output[key] = dumps_metadata(value)
57+
output[key] = dumps_metadata(serde, value)
5858
return output
5959
else:
6060
return serde.dumps_typed(metadata)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
from typing import Any
3+
4+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
5+
from pymongo import MongoClient
6+
7+
from langgraph.checkpoint.mongodb import MongoDBSaver
8+
9+
MONGODB_URI = os.environ.get(
10+
"MONGODB_URI", "mongodb://localhost:27017/?directConnection=true"
11+
)
12+
DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
13+
COLLECTION_NAME = "serde_checkpoints"
14+
15+
16+
class CustomSerializer(JsonPlusSerializer):
17+
def __init__(self) -> None:
18+
super().__init__()
19+
self.dumps_called = False
20+
self.loads_called = False
21+
22+
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
23+
self.dumps_called = True
24+
return super().dumps_typed(obj)
25+
26+
def loads_typed(self, obj: tuple[str, bytes]) -> Any:
27+
self.loads_called = True
28+
return super().loads_typed(obj)
29+
30+
31+
def test_custom_serde(input_data: dict[str, Any]) -> None:
32+
client: MongoClient = MongoClient(MONGODB_URI)
33+
db = client[DB_NAME]
34+
db.drop_collection(COLLECTION_NAME)
35+
36+
custom_serializer = CustomSerializer()
37+
38+
with MongoDBSaver.from_conn_string(
39+
MONGODB_URI, DB_NAME, COLLECTION_NAME, serde=custom_serializer
40+
) as saver:
41+
put_config = saver.put(
42+
input_data["config_1"],
43+
input_data["chkpnt_1"],
44+
input_data["metadata_1"],
45+
{},
46+
)
47+
48+
assert custom_serializer.dumps_called
49+
50+
retrieved_checkpoint_tuple = saver.get_tuple(put_config)
51+
52+
assert custom_serializer.loads_called
53+
54+
assert retrieved_checkpoint_tuple is not None
55+
assert retrieved_checkpoint_tuple.checkpoint == input_data["chkpnt_1"]
56+
assert retrieved_checkpoint_tuple.metadata == input_data["metadata_1"]

0 commit comments

Comments
 (0)