Skip to content

Commit 43c3e99

Browse files
caseyclementsaclark4lifedependabot[bot]blink1073NoahStapp
authored
INTPYTHON-850 Add script to migrate checkpoint collections created before v0.2.2 (langchain-ai#291)
<!-- We do not accept pull requests that are primarily or substantially generated by AI tools (ChatGPT, Copilot, etc.). All contributions must be written and understood by human contributors. --> [INTPYTHON-850](https://jira.mongodb.org/browse/INTPYTHON-850) ## Summary <!-- What is this PR introducing? If context is already provided from the JIRA ticket, still place it in the Pull Request as you should not make the reviewer do digging for a basic summary. --> Addresses issue langchain-ai#287 by providing a migration script: `migrate_checkpoints_to_typed_metadata.py` Additional typing information was added in v0.2.2 to address [CVE - LangGraph Checkpoint affected by RCE in "json" mode of JsonPlusSerializer](https://osv.dev/vulnerability/GHSA-wwqv-p2pp-99h5). ## Changes in this PR ``` ~/src/langchain-mongodb/libs/langgraph-checkpoint-mongodb (INTPYTHON-850-SerializerMigration) $ uv run python scripts/migrate_checkpoints_to_typed_metadata.py -h usage: migrate_checkpoints_to_typed_metadata.py [-h] [--mongodb-uri MONGODB_URI] --db DB --collections COLLECTIONS [COLLECTIONS ...] [--batch-size BATCH_SIZE] [--suffix SUFFIX] [--workers WORKERS] [--dry-run] [--clear-destination] [--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] Migrate langgraph checkpoint metadata to typed format (>= v0.2.2). options: -h, --help show this help message and exit --mongodb-uri MONGODB_URI MongoDB connection URI --db DB Database name containing checkpoint collections --collections COLLECTIONS [COLLECTIONS ...] One or more checkpoint collection names to migrate --batch-size BATCH_SIZE Number of documents per insert batch --suffix SUFFIX Suffix for migrated collections (default: -new) --workers WORKERS Number of worker processes (default: 1) --dry-run Run migration without writing any data --clear-destination Delete destination collection before migrating --log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL} Logging verbosity ``` <!-- What changes did you make to the code? What new APIs (public or private) were added, removed, or edited to generate the desired outcome explained in the above summary? --> ## Test Plan <!-- How did you test the code? If you added unit tests, you can say that. If you didn’t introduce unit tests, explain why. All code should be tested in some way – so please list what your validation strategy was. --> No new unit tests were created for this. Instead, work was done manually to create a test checkpoint collection in v0.2.2, migrate this collection with the script, and then run a modified version of tests/test_sync.py against the latest MongoDBSaver referencing the migrated collection. ## Checklist <!-- Do not delete the items provided on this checklist --> ### Checklist for Author - [x] Did you update the changelog (if necessary)? - [ ] Is the intention of the code captured in relevant tests? - [ ] If there are new TODOs, has a related JIRA ticket been created? - [ ] Has a MongoDB Employee run [the patch build of this PR](https://github.com/mongodb-labs/ai-ml-pipeline-testing?tab=readme-ov-file#running-a-patch-build-of-a-given-pr)? ### Checklist for Reviewer - [ ] Does the title of the PR reference a JIRA Ticket? - [ ] Do you fully understand the implementation? (Would you be comfortable explaining how this code works to someone else?) - [ ] Is all relevant documentation (README or docstring) updated? --------- Co-authored-by: Jeffrey A. Clark <aclark@aclark.net> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Steven Silvester <steve.silvester@mongodb.com> Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
1 parent d4594a2 commit 43c3e99

File tree

3 files changed

+289
-0
lines changed

3 files changed

+289
-0
lines changed

libs/langgraph-checkpoint-mongodb/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
---
44

5+
## Changes in version 0.3.1 (DATE ???)
6+
- Fixes issue #287 to migrate checkpoint data created with v<0.2.2 with a migration script: [migrate_checkpoints_to_typed_metadata.py](./scripts/migrate_checkpoints_to_typed_metadata.py).
7+
58
## Changes in version 0.3.0 (2025/11/19)
69
- Allow custom serde objects to be passed to MongoDBSaver for serialization/deserialization.
710
- Remove the deprecated AsyncMongoDBSaver class, which has been replaced by MongoDBSaver's async methods.

libs/langgraph-checkpoint-mongodb/scripts/__init__.py

Whitespace-only changes.
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# /// script
2+
# requires-python = ">=3.10"
3+
# dependencies = [
4+
# "pymongo>=4.6,<5",
5+
# "langgraph-checkpoint-mongodb>=0.2.2",
6+
# ]
7+
# ///
8+
9+
"""Script to migrate metadata of checkpoint collections
10+
- from <=v0.2.1 which is json
11+
- to >=v0.2.2 which is typed (defaulting to msgpack)
12+
13+
Data that was created on <v0.2.2 cannot be read by newer langgraph-checkpoint-mongodb.
14+
15+
Invoke using PEP 723 (Inline Script Metadata):
16+
`$ uv run scripts/migrate_checkpoints_to_typed_metadata.py -h`
17+
18+
Notes:
19+
- writes_collections is not in scope as it has always used serde.dumps_typed / serde.loads_typed
20+
21+
"""
22+
23+
import argparse
24+
import logging
25+
import multiprocessing as mp
26+
import time
27+
from typing import Any, Union
28+
29+
from bson.raw_bson import RawBSONDocument
30+
from langgraph.checkpoint.base import CheckpointMetadata
31+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
32+
from pymongo import MongoClient
33+
from pymongo.collection import Collection
34+
from pymongo.errors import BulkWriteError
35+
from pymongo.typings import _DocumentType
36+
37+
from langgraph.checkpoint.mongodb import MongoDBSaver
38+
39+
serde = JsonPlusSerializer()
40+
41+
42+
def parse_args() -> argparse.Namespace:
43+
parser = argparse.ArgumentParser(
44+
description="Migrate langgraph checkpoint metadata to typed format (>= v0.2.2)."
45+
)
46+
47+
parser.add_argument(
48+
"--mongodb-uri",
49+
default="mongodb://localhost:27017/?directConnection=true",
50+
help="MongoDB connection URI",
51+
)
52+
53+
parser.add_argument(
54+
"--db",
55+
required=True,
56+
help="Database name containing checkpoint collections",
57+
)
58+
59+
parser.add_argument(
60+
"--collections",
61+
nargs="+",
62+
required=True,
63+
help="One or more checkpoint collection names to migrate",
64+
)
65+
66+
parser.add_argument(
67+
"--batch-size",
68+
type=int,
69+
default=1000,
70+
help="Number of documents per insert batch",
71+
)
72+
73+
parser.add_argument(
74+
"--suffix",
75+
default="-new",
76+
help="Suffix for migrated collections (default: -new)",
77+
)
78+
79+
parser.add_argument(
80+
"--workers",
81+
type=int,
82+
default=1,
83+
help="Number of worker processes (default: 1)",
84+
)
85+
86+
parser.add_argument(
87+
"--dry-run",
88+
action="store_true",
89+
help="Run migration without writing any data",
90+
)
91+
92+
parser.add_argument(
93+
"--clear-destination",
94+
action="store_true",
95+
help="Delete destination collection before migrating",
96+
)
97+
98+
parser.add_argument(
99+
"--log-level",
100+
default="INFO",
101+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
102+
help="Logging verbosity",
103+
)
104+
105+
return parser.parse_args()
106+
107+
108+
def loads_metadata_orig(metadata: dict[str, Any]) -> CheckpointMetadata:
109+
if isinstance(metadata, dict):
110+
return {k: loads_metadata_orig(v) for k, v in metadata.items()}
111+
return serde.loads_typed(("json", metadata))
112+
113+
114+
def dumps_metadata_new(
115+
metadata: Union[CheckpointMetadata, Any],
116+
) -> Union[bytes, dict[str, Any]]:
117+
if isinstance(metadata, dict):
118+
return {k: dumps_metadata_new(v) for k, v in metadata.items()}
119+
return serde.dumps_typed(metadata)
120+
121+
122+
def insert_non_duplicates(
123+
clxn: Collection, buffer: list[Union[_DocumentType, RawBSONDocument]]
124+
) -> None:
125+
try:
126+
clxn.insert_many(buffer, ordered=False)
127+
except BulkWriteError as e:
128+
write_errors = e.details.get("writeErrors", [])
129+
non_dupe_errors = [err for err in write_errors if err.get("code") != 11000]
130+
if non_dupe_errors:
131+
raise
132+
finally:
133+
buffer.clear()
134+
135+
136+
def worker_migrate(
137+
worker_id: int,
138+
args: argparse.Namespace,
139+
collection_name: str,
140+
) -> dict[str, int]:
141+
"""
142+
Worker process that migrates a shard of documents determined by _id hash.
143+
"""
144+
client: MongoClient = MongoClient(args.mongodb_uri)
145+
db = client[args.db]
146+
147+
clxn_orig = db[collection_name]
148+
dest_name = f"{collection_name}{args.suffix}"
149+
clxn_new = db[dest_name]
150+
151+
scanned = 0
152+
migrated = 0
153+
buffer = []
154+
155+
cursor = clxn_orig.find({}, batch_size=args.batch_size)
156+
157+
for doc in cursor:
158+
# Deterministic partition
159+
if hash(doc["_id"]) % args.workers != worker_id:
160+
continue
161+
162+
scanned += 1
163+
164+
if "metadata" in doc:
165+
doc["metadata"] = dumps_metadata_new(loads_metadata_orig(doc["metadata"]))
166+
167+
buffer.append(doc)
168+
migrated += 1
169+
170+
if len(buffer) >= args.batch_size:
171+
if not args.dry_run:
172+
insert_non_duplicates(clxn_new, buffer)
173+
else:
174+
buffer.clear()
175+
176+
if buffer:
177+
if not args.dry_run:
178+
insert_non_duplicates(clxn_new, buffer)
179+
else:
180+
buffer.clear()
181+
182+
return {
183+
"scanned": scanned,
184+
"migrated": migrated,
185+
}
186+
187+
188+
def main() -> None:
189+
args = parse_args()
190+
191+
logging.basicConfig(
192+
level=getattr(logging, args.log_level),
193+
format="%(asctime)s [%(levelname)s] %(message)s",
194+
)
195+
196+
start_time = time.time()
197+
198+
logging.info("Beginning checkpoint data migration")
199+
logging.info(f"mongodb_uri={args.mongodb_uri}")
200+
logging.info(f"db={args.db}")
201+
logging.info(f"collections={args.collections}")
202+
logging.info(f"batch_size={args.batch_size}")
203+
logging.info(f"suffix={args.suffix}")
204+
logging.info(f"dry_run={args.dry_run}")
205+
206+
total_scanned = 0
207+
total_migrated = 0
208+
209+
for collection_name in args.collections:
210+
logging.info(f"--- Migrating collection: {collection_name} ---")
211+
212+
client: MongoClient = MongoClient(args.mongodb_uri)
213+
db = client[args.db]
214+
215+
clxn_orig = db[collection_name]
216+
dest_name = f"{collection_name}{args.suffix}"
217+
clxn_new = db[dest_name]
218+
219+
if args.clear_destination and not args.dry_run:
220+
logging.warning(f"Clearing destination collection {dest_name}")
221+
clxn_new.delete_many({})
222+
223+
n_orig = clxn_orig.count_documents({})
224+
logging.info(f"Source collection contains {n_orig} documents")
225+
226+
if n_orig == 0:
227+
logging.warning(f"Skipping empty or missing collection: {collection_name}")
228+
continue
229+
230+
if args.workers == 1:
231+
# Single-process fallback (existing behavior)
232+
result = worker_migrate(0, args, collection_name)
233+
total_scanned += result["scanned"]
234+
total_migrated += result["migrated"]
235+
else:
236+
logging.info(f"Starting {args.workers} workers")
237+
238+
with mp.Pool(processes=args.workers) as pool:
239+
results = pool.starmap(
240+
worker_migrate,
241+
[
242+
(worker_id, args, collection_name)
243+
for worker_id in range(args.workers)
244+
],
245+
)
246+
247+
for r in results:
248+
total_scanned += r["scanned"]
249+
total_migrated += r["migrated"]
250+
251+
if not args.dry_run:
252+
n_new = clxn_new.count_documents({})
253+
assert n_new == total_migrated or n_new <= n_orig
254+
255+
saver_new = MongoDBSaver(
256+
client=client,
257+
db_name=args.db,
258+
checkpoint_collection_name=dest_name,
259+
)
260+
261+
checkpoints_new = saver_new.list(config=None, limit=1)
262+
sample_thread = next(checkpoints_new).config["configurable"]["thread_id"]
263+
sample_checkpoint = saver_new.get_tuple(
264+
config={"configurable": {"thread_id": sample_thread}}
265+
)
266+
if sample_checkpoint is not None:
267+
logging.debug(
268+
f"[{collection_name}] Sample metadata: {sample_checkpoint.metadata}"
269+
)
270+
271+
elapsed = time.time() - start_time
272+
rate = total_migrated / elapsed if elapsed > 0 else 0
273+
274+
logging.info("=== Migration Summary ===")
275+
logging.info(f"Collections processed: {len(args.collections)}")
276+
logging.info(f"Documents scanned: {total_scanned}")
277+
logging.info(f"Documents migrated: {total_migrated}")
278+
logging.info(f"Elapsed time: {elapsed:.2f}s")
279+
logging.info(f"Throughput: {rate:.2f} docs/sec")
280+
281+
if args.dry_run:
282+
logging.info("Dry-run mode enabled: no data was written")
283+
284+
285+
if __name__ == "__main__":
286+
main()

0 commit comments

Comments
 (0)