|
| 1 | +import os |
| 2 | +from collections.abc import Generator |
| 3 | +from operator import add |
| 4 | +from typing import Annotated, Any, TypedDict |
| 5 | + |
| 6 | +import pytest |
| 7 | +from langchain_core.runnables import RunnableConfig |
| 8 | +from pymongo import MongoClient |
| 9 | +from typing_extensions import NotRequired |
| 10 | + |
| 11 | +from langgraph.checkpoint.mongodb import MongoDBSaver |
| 12 | +from langgraph.graph import END, START, StateGraph |
| 13 | +from langgraph.types import StateSnapshot |
| 14 | + |
| 15 | +# Test configuration |
| 16 | +MONGODB_URI = os.environ.get( |
| 17 | + "MONGODB_URI", "mongodb://127.0.0.1:27017?directConnection=true" |
| 18 | +) |
| 19 | + |
| 20 | + |
| 21 | +class ExpenseState(TypedDict): |
| 22 | + amount: NotRequired[int] |
| 23 | + version: NotRequired[int] |
| 24 | + approved: NotRequired[bool] |
| 25 | + messages: Annotated[list[str], add] |
| 26 | + |
| 27 | + |
| 28 | +def add_expense_node(state: ExpenseState) -> dict[str, Any]: |
| 29 | + """Node adds expense and a message""" |
| 30 | + return dict(amount=100, version=1, approved=False, messages=["Added new expense"]) |
| 31 | + |
| 32 | + |
| 33 | +def validate_expense_node(state: ExpenseState) -> dict[str, Any]: |
| 34 | + """Node that processes data based on current state""" |
| 35 | + if state.get("amount") == 200: |
| 36 | + return dict(approved=True, messages=["expense approved"]) |
| 37 | + else: |
| 38 | + return dict(approved=False, messages=["expense denied"]) |
| 39 | + |
| 40 | + |
| 41 | +@pytest.fixture( |
| 42 | + params=[None, 60 * 60], |
| 43 | + ids=["ttl_none", "ttl_3600"], |
| 44 | +) |
| 45 | +def checkpointer(request: Any) -> Generator[MongoDBSaver]: |
| 46 | + db_name = "langgraph_timetravel_db" |
| 47 | + checkpoint_collection_name = "checkpoints" |
| 48 | + writes_collection_name = "checkpoint_writes" |
| 49 | + |
| 50 | + # Initialize MongoDB checkpointer |
| 51 | + client: MongoClient = MongoClient(MONGODB_URI) |
| 52 | + |
| 53 | + # Clean up any existing test data. |
| 54 | + client.drop_database(db_name) |
| 55 | + |
| 56 | + saver = MongoDBSaver( |
| 57 | + client=client, |
| 58 | + db_name=db_name, |
| 59 | + collection_name=checkpoint_collection_name, |
| 60 | + WRITES_COLLECTION_NAME=writes_collection_name, |
| 61 | + ttl=request.param, |
| 62 | + ) |
| 63 | + |
| 64 | + # Can use this to compare |
| 65 | + # saver = InMemorySaver() |
| 66 | + |
| 67 | + yield saver |
| 68 | + |
| 69 | + client[db_name].drop_collection(checkpoint_collection_name) |
| 70 | + client[db_name].drop_collection(writes_collection_name) |
| 71 | + client.close() |
| 72 | + |
| 73 | + |
| 74 | +def test(checkpointer: MongoDBSaver) -> None: |
| 75 | + """Test ability to use checkpointer to update exact state of graph. |
| 76 | +
|
| 77 | + In this simple example, we assume an initial state has been set incorrectly. |
| 78 | + To fix this, instead of rerunning from start, |
| 79 | + we find the incorrect node, update_state, and continue (by passing None to invoke or stream). |
| 80 | +
|
| 81 | + This example does not use interrupt/resume as one might, for example, |
| 82 | + in an expense report approval workflow. |
| 83 | + """ |
| 84 | + initial_state: ExpenseState = dict( |
| 85 | + amount=0, version=0, approved=False, messages=["Initial state"] |
| 86 | + ) |
| 87 | + config: RunnableConfig = dict(configurable=dict(thread_id="test-time-travel")) |
| 88 | + |
| 89 | + # Create the graph, which should be a 2-step procedure |
| 90 | + workflow = StateGraph(ExpenseState) |
| 91 | + workflow.add_node("add_expense", add_expense_node) |
| 92 | + workflow.add_node("validate_expense", validate_expense_node) |
| 93 | + workflow.add_edge(START, "add_expense") |
| 94 | + workflow.add_edge("validate_expense", END) |
| 95 | + workflow.add_edge("add_expense", "validate_expense") |
| 96 | + graph = workflow.compile(checkpointer=checkpointer) |
| 97 | + |
| 98 | + # Run the graph |
| 99 | + graph.invoke(input=initial_state, config=config) # type:ignore[arg-type] |
| 100 | + |
| 101 | + # Check to see whether the final state is approved |
| 102 | + final_state = graph.get_state(config=config) |
| 103 | + |
| 104 | + # It is not approved. |
| 105 | + assert not final_state.values["approved"] |
| 106 | + |
| 107 | + # Let's use time-travel to find the checkpoint before "add_expense" |
| 108 | + checkpoints: list[StateSnapshot] = list(graph.get_state_history(config)) |
| 109 | + # checkpoints: list[CheckpointTuple] = list(checkpointer.list(config)) |
| 110 | + print(f"\nFound {len(checkpoints)} checkpoints") |
| 111 | + |
| 112 | + target_checkpoint = None |
| 113 | + for checkpoint in checkpoints: |
| 114 | + # Look for checkpoint after increment but before final processing |
| 115 | + if ( |
| 116 | + checkpoint.metadata and checkpoint.metadata.get("step") == 1 |
| 117 | + ): # Before validate node |
| 118 | + target_checkpoint = checkpoint |
| 119 | + break |
| 120 | + |
| 121 | + for state in checkpoints: |
| 122 | + if state.metadata: |
| 123 | + print(f"\n{state.metadata["step"]=}") |
| 124 | + print(f"{state.next=}") |
| 125 | + print(f"{state.config["configurable"]["checkpoint_id"]=}") |
| 126 | + print(f"{state.values=}") |
| 127 | + |
| 128 | + # Get state at that checkpoint |
| 129 | + assert target_checkpoint |
| 130 | + past_state = graph.get_state(target_checkpoint.config) |
| 131 | + |
| 132 | + # Update the expense amount to 200 that validate amounts |
| 133 | + updated_state = dict(**past_state.values) |
| 134 | + # updated_state = {} |
| 135 | + updated_state["amount"] = 200 |
| 136 | + updated_state["version"] = 2 |
| 137 | + updated_state["messages"] += ["Updated state"] |
| 138 | + |
| 139 | + updated_config = graph.update_state( |
| 140 | + config=target_checkpoint.config, values=updated_state |
| 141 | + ) |
| 142 | + |
| 143 | + # Continue from the checkpoint |
| 144 | + print("\nContinuing execution with stream(None, config)...") |
| 145 | + final_step = None |
| 146 | + for step in graph.stream(None, updated_config): |
| 147 | + print(f"Continuation step: {step}") |
| 148 | + final_step = step |
| 149 | + |
| 150 | + # Verify the final result |
| 151 | + assert isinstance(final_step, dict) |
| 152 | + assert final_step["validate_expense"]["approved"] |
| 153 | + # Note that all values are not in the final step |
| 154 | + assert "amount" not in final_step["validate_expense"] |
| 155 | + # They ARE available from graph.get_state |
| 156 | + final_state = graph.get_state(updated_config) |
| 157 | + assert final_state.values["amount"] == 200 |
| 158 | + assert set(final_state.values.keys()) == { |
| 159 | + "amount", |
| 160 | + "version", |
| 161 | + "messages", |
| 162 | + "approved", |
| 163 | + } |
0 commit comments