Skip to content

Commit 3c75e41

Browse files
authored
fix(langgraph): do not apply pending writes when updating state (#6389)
PR #6195 fixed `bulk_update_state` to populate `task.result` by calling `prepare_next_tasks` to discover task IDs. Before #6195, prepare_next_tasks was gated by the condition `CONFIG_KEY_CHECKPOINT_ID not in config[CONF]` - so it only ran if we were resuming from an empty checkpoint. This check was removed in order to properly populate task results. However, the removal of this check inadvertently applied pending writes during manual state updates which caused issues when forking: - When you fork from a checkpoint by calling `update_state(config, new_values, as_node="mynode")`, pending writes from the original execution were being applied - This caused stale data to leak into forked threads (eg. old tool call results appearing in forked execution) Changes Removed pending writes application from `bulk_update_state` and `abulk_update_state`: - Still call `prepare_next_tasks` to discover task IDs, but skip the code that applies null writes and regular pending writes Tests - Added `test_fork_does_not_apply_pending_writes` for sync and async which verifies forking doesn't include stale pending writes from original execution
1 parent 575401c commit 3c75e41

File tree

3 files changed

+84
-53
lines changed

3 files changed

+84
-53
lines changed

libs/langgraph/langgraph/pregel/main.py

Lines changed: 6 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,7 +1702,9 @@ def perform_superstep(
17021702
# we use the task id generated by prepare_next_tasks
17031703
node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
17041704
if saved is not None and saved.pending_writes is not None:
1705-
# tasks for this checkpoint
1705+
# we call prepare_next_tasks to discover the task IDs that
1706+
# would have been generated, so we can reuse them and
1707+
# properly populate task.result in state history
17061708
next_tasks = prepare_next_tasks(
17071709
checkpoint,
17081710
saved.pending_writes,
@@ -1721,32 +1723,6 @@ def perform_superstep(
17211723
for t in next_tasks.values():
17221724
node_to_task_ids[t.name].append(t.id)
17231725

1724-
# apply null writes
1725-
if null_writes := [
1726-
w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
1727-
]:
1728-
apply_writes(
1729-
checkpoint,
1730-
channels,
1731-
[PregelTaskWrites((), INPUT, null_writes, [])],
1732-
checkpointer.get_next_version,
1733-
self.trigger_to_nodes,
1734-
)
1735-
# apply writes
1736-
for tid, k, v in saved.pending_writes:
1737-
if k in (ERROR, INTERRUPT):
1738-
continue
1739-
if tid not in next_tasks:
1740-
continue
1741-
next_tasks[tid].writes.append((k, v))
1742-
if tasks := [t for t in next_tasks.values() if t.writes]:
1743-
apply_writes(
1744-
checkpoint,
1745-
channels,
1746-
tasks,
1747-
checkpointer.get_next_version,
1748-
self.trigger_to_nodes,
1749-
)
17501726
valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
17511727
if len(updates) == 1:
17521728
values, as_node, task_id = updates[0]
@@ -2167,7 +2143,9 @@ async def aperform_superstep(
21672143
# we use the task id generated by prepare_next_tasks
21682144
node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
21692145
if saved is not None and saved.pending_writes is not None:
2170-
# tasks for this checkpoint
2146+
# we call prepare_next_tasks to discover the task IDs that
2147+
# would have been generated, so we can reuse them and
2148+
# properly populate task.result in state history
21712149
next_tasks = prepare_next_tasks(
21722150
checkpoint,
21732151
saved.pending_writes,
@@ -2186,31 +2164,6 @@ async def aperform_superstep(
21862164
for t in next_tasks.values():
21872165
node_to_task_ids[t.name].append(t.id)
21882166

2189-
# apply null writes
2190-
if null_writes := [
2191-
w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
2192-
]:
2193-
apply_writes(
2194-
checkpoint,
2195-
channels,
2196-
[PregelTaskWrites((), INPUT, null_writes, [])],
2197-
checkpointer.get_next_version,
2198-
self.trigger_to_nodes,
2199-
)
2200-
for tid, k, v in saved.pending_writes:
2201-
if k in (ERROR, INTERRUPT):
2202-
continue
2203-
if tid not in next_tasks:
2204-
continue
2205-
next_tasks[tid].writes.append((k, v))
2206-
if tasks := [t for t in next_tasks.values() if t.writes]:
2207-
apply_writes(
2208-
checkpoint,
2209-
channels,
2210-
tasks,
2211-
checkpointer.get_next_version,
2212-
self.trigger_to_nodes,
2213-
)
22142167
valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
22152168
if len(updates) == 1:
22162169
values, as_node, task_id = updates[0]

libs/langgraph/tests/test_pregel.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8805,3 +8805,43 @@ def node_c(state: State):
88058805
InvalidUpdateError, match="Can receive only one Overwrite value per super-step."
88068806
):
88078807
graph.invoke({"messages": ["START"]}, config)
8808+
8809+
8810+
def test_fork_does_not_apply_pending_writes(
    sync_checkpointer: BaseCheckpointSaver,
) -> None:
    """Forking via ``update_state`` must not replay pending writes.

    Before the fix, pending writes recorded during the original run leaked
    into a thread forked with ``update_state(..., as_node=...)``, so stale
    node results were applied on top of the forked update.
    """

    class State(TypedDict):
        value: Annotated[int, operator.add]

    def node_a(state: State) -> State:
        return {"value": 10}

    def node_b(state: State) -> State:
        return {"value": 100}

    builder = StateGraph(State)
    builder.add_node("node_a", node_a)
    builder.add_node("node_b", node_b)
    builder.add_edge(START, "node_a")
    builder.add_edge("node_a", "node_b")
    graph = builder.compile(checkpointer=sync_checkpointer)

    thread_config = {"configurable": {"thread_id": "1"}}
    graph.invoke({"value": 1}, thread_config)

    # Locate the checkpoint taken just before node_a executed.
    snapshots = list(graph.get_state_history(thread_config))
    before_a = next(s for s in snapshots if s.next == ("node_a",))

    # Fork: replace node_a's contribution with a manual update.
    forked = graph.update_state(
        before_a.config, {"value": 20}, as_node="node_a"
    )

    # Resuming the fork should run node_b only; node_a's original
    # write (10) must NOT be re-applied from pending writes.
    result = graph.invoke(None, forked)

    # 1 (input) + 20 (forked node_a) + 100 (node_b) = 121
    assert result == {"value": 121}

libs/langgraph/tests/test_pregel_async.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9249,3 +9249,41 @@ def first_task_result(history: list[StateSnapshot], node: str) -> Any:
92499249

92509250
assert bulk_start_result == ref_start_result == {"num": 1, "text": "one"}
92519251
assert bulk_double_result == ref_double_result == {"num": 2, "text": "oneone"}
9252+
9253+
9254+
async def test_fork_does_not_apply_pending_writes(
    async_checkpointer: BaseCheckpointSaver,
) -> None:
    """Forking via ``aupdate_state`` must not replay pending writes.

    Before the fix, pending writes recorded during the original run leaked
    into a thread forked with ``aupdate_state(..., as_node=...)``, so stale
    node results were applied on top of the forked update.
    """

    class State(TypedDict):
        value: Annotated[int, operator.add]

    def node_a(state: State) -> State:
        return {"value": 10}

    def node_b(state: State) -> State:
        return {"value": 100}

    builder = StateGraph(State)
    builder.add_node("node_a", node_a)
    builder.add_node("node_b", node_b)
    builder.add_edge(START, "node_a")
    builder.add_edge("node_a", "node_b")
    graph = builder.compile(checkpointer=async_checkpointer)

    thread_config = {"configurable": {"thread_id": "1"}}
    await graph.ainvoke({"value": 1}, thread_config)

    # Locate the checkpoint taken just before node_a executed.
    snapshots = [c async for c in graph.aget_state_history(thread_config)]
    before_a = next(s for s in snapshots if s.next == ("node_a",))

    # Fork: replace node_a's contribution with a manual update.
    forked = await graph.aupdate_state(
        before_a.config, {"value": 20}, as_node="node_a"
    )

    # Resuming the fork should run node_b only; node_a's original
    # write (10) must NOT be re-applied from pending writes.
    result = await graph.ainvoke(None, forked)

    # 1 (input) + 20 (forked node_a) + 100 (node_b) = 121
    assert result == {"value": 121}

0 commit comments

Comments
 (0)