Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions libs/prebuilt/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,18 @@ def _func(
# Construct ToolRuntime instances at the top level for each tool call
tool_runtimes = []
for call, cfg in zip(tool_calls, config_list, strict=False):
# Update checkpoint_ns to include tool_call_id.
# This ensures that if the tool is a subgraph, its events will be
# namespaced with the tool_call_id, allowing correlation in the UI.
if "configurable" not in cfg:
cfg["configurable"] = {}

parent_ns = cfg["configurable"].get("checkpoint_ns", "")
if parent_ns:
cfg["configurable"]["checkpoint_ns"] = f"{parent_ns}|{call['id']}"
else:
cfg["configurable"]["checkpoint_ns"] = call["id"]

state = self._extract_state(input)
tool_runtime = ToolRuntime(
state=state,
Expand Down Expand Up @@ -828,6 +840,18 @@ async def _afunc(
# Construct ToolRuntime instances at the top level for each tool call
tool_runtimes = []
for call, cfg in zip(tool_calls, config_list, strict=False):
# Update checkpoint_ns to include tool_call_id.
# This ensures that if the tool is a subgraph, its events will be
# namespaced with the tool_call_id, allowing correlation in the UI.
if "configurable" not in cfg:
cfg["configurable"] = {}

parent_ns = cfg["configurable"].get("checkpoint_ns", "")
if parent_ns:
cfg["configurable"]["checkpoint_ns"] = f"{parent_ns}|{call['id']}"
else:
cfg["configurable"]["checkpoint_ns"] = call["id"]

state = self._extract_state(input)
tool_runtime = ToolRuntime(
state=state,
Expand Down
258 changes: 258 additions & 0 deletions libs/prebuilt/tests/test_tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,3 +1902,261 @@ def get_info(rt: ToolRuntime[MyContext]):
assert tool_message.type == "tool"
assert tool_message.content == "test_info"
assert tool_message.tool_call_id == "call_1"


async def test_tool_node_checkpoint_ns_includes_tool_call_id() -> None:
"""Test that checkpoint_ns in the config includes the tool_call_id when tools are invoked.

This verifies the fix that ensures the checkpoint_ns is updated to include the
tool_call_id for proper event namespacing in the UI.
"""
# Global variable to capture the config passed to the tool
captured_configs: list[dict] = []

@dec_tool
def config_capturing_tool(runtime: ToolRuntime) -> str:
"""Tool that captures the runtime config for verification."""
captured_configs.append(runtime.config.copy())
return f"checkpoint_ns: {runtime.config.get('configurable', {}).get('checkpoint_ns', 'NOT_SET')}"

@dec_tool
async def async_config_capturing_tool(runtime: ToolRuntime) -> str:
"""Async tool that captures the runtime config for verification."""
captured_configs.append(runtime.config.copy())
return f"checkpoint_ns: {runtime.config.get('configurable', {}).get('checkpoint_ns', 'NOT_SET')}"

tool_call_id = "test_tool_call_abc123"

# Test sync execution path
captured_configs.clear()
result = ToolNode([config_capturing_tool]).invoke(
{
"messages": [
AIMessage(
"test",
tool_calls=[
{
"name": "config_capturing_tool",
"args": {},
"id": tool_call_id,
}
],
)
]
},
config=_create_config_with_runtime(),
)

tool_message = result["messages"][-1]
assert tool_message.tool_call_id == tool_call_id
assert tool_message.content == f"checkpoint_ns: {tool_call_id}"

# Verify the captured config contains the tool_call_id in checkpoint_ns
assert len(captured_configs) == 1
captured_config = captured_configs[0]
checkpoint_ns = captured_config.get("configurable", {}).get("checkpoint_ns", "")
assert checkpoint_ns == tool_call_id, (
f"Expected checkpoint_ns to be '{tool_call_id}', got '{checkpoint_ns}'"
)

# Test async execution path
captured_configs.clear()
result_async = await ToolNode([async_config_capturing_tool]).ainvoke(
{
"messages": [
AIMessage(
"test",
tool_calls=[
{
"name": "async_config_capturing_tool",
"args": {},
"id": tool_call_id,
}
],
)
]
},
config=_create_config_with_runtime(),
)

tool_message_async = result_async["messages"][-1]
assert tool_message_async.tool_call_id == tool_call_id
assert tool_message_async.content == f"checkpoint_ns: {tool_call_id}"

# Verify the captured config contains the tool_call_id in checkpoint_ns
assert len(captured_configs) == 1
captured_config_async = captured_configs[0]
checkpoint_ns_async = captured_config_async.get("configurable", {}).get(
"checkpoint_ns", ""
)
assert checkpoint_ns_async == tool_call_id, (
f"Expected checkpoint_ns to be '{tool_call_id}', got '{checkpoint_ns_async}'"
)


async def test_tool_node_checkpoint_ns_with_parent_ns() -> None:
"""Test that checkpoint_ns appends tool_call_id when parent_ns already exists.

When a tool is invoked from within a subgraph and already has a parent_ns,
the checkpoint_ns should be formatted as "parent_ns|tool_call_id".
"""
captured_configs: list[dict] = []

@dec_tool
def nested_tool(runtime: ToolRuntime) -> str:
"""Tool that captures the runtime config for verification."""
captured_configs.append(runtime.config.copy())
return f"checkpoint_ns: {runtime.config.get('configurable', {}).get('checkpoint_ns', 'NOT_SET')}"

parent_ns = "parent_call_id_xyz"
tool_call_id = "child_tool_call_abc123"

# Test with existing parent_ns in the config
config = _create_config_with_runtime()
config["configurable"]["checkpoint_ns"] = parent_ns

captured_configs.clear()
result = ToolNode([nested_tool]).invoke(
{
"messages": [
AIMessage(
"test",
tool_calls=[
{
"name": "nested_tool",
"args": {},
"id": tool_call_id,
}
],
)
]
},
config=config,
)

tool_message = result["messages"][-1]
assert tool_message.tool_call_id == tool_call_id

# Verify the captured config contains the combined checkpoint_ns
assert len(captured_configs) == 1
captured_config = captured_configs[0]
checkpoint_ns = captured_config.get("configurable", {}).get("checkpoint_ns", "")
expected_ns = f"{parent_ns}|{tool_call_id}"
assert checkpoint_ns == expected_ns, (
f"Expected checkpoint_ns to be '{expected_ns}', got '{checkpoint_ns}'"
)


def test_tool_node_checkpoint_ns_sync_path() -> None:
"""Test that checkpoint_ns includes tool_call_id for sync path with tool_calls input."""
captured_configs: list[dict] = []

@dec_tool
def sync_capture_tool(runtime: ToolRuntime) -> str:
"""Sync tool that captures the runtime config."""
captured_configs.append(runtime.config.copy())
return "done"

tool_call_id = "sync_call_12345"

# Test with direct tool_calls list input
captured_configs.clear()
result = ToolNode([sync_capture_tool]).invoke(
[
{
"name": "sync_capture_tool",
"args": {},
"id": tool_call_id,
"type": "tool_call",
}
],
config=_create_config_with_runtime(),
)

assert result["messages"] == [
ToolMessage(
content="done", tool_call_id=tool_call_id, name="sync_capture_tool"
),
]

# Verify checkpoint_ns contains tool_call_id
assert len(captured_configs) == 1
checkpoint_ns = captured_configs[0].get("configurable", {}).get("checkpoint_ns", "")
assert checkpoint_ns == tool_call_id, (
f"Expected checkpoint_ns to be '{tool_call_id}', got '{checkpoint_ns}'"
)


async def test_tool_node_checkpoint_ns_multiple_tools() -> None:
"""Test checkpoint_ns for multiple concurrent tool calls.

Each tool call should have its own checkpoint_ns containing its specific tool_call_id.
"""
captured_configs: list[dict] = []

@dec_tool
def tool_a(runtime: ToolRuntime) -> str:
"""Tool A that captures its config."""
captured_configs.append({"name": "tool_a", "config": runtime.config.copy()})
return "a"

@dec_tool
def tool_b(runtime: ToolRuntime) -> str:
"""Tool B that captures its config."""
captured_configs.append({"name": "tool_b", "config": runtime.config.copy()})
return "b"

tool_call_id_a = "call_id_a_001"
tool_call_id_b = "call_id_b_002"

captured_configs.clear()
result = await ToolNode([tool_a, tool_b]).ainvoke(
{
"messages": [
AIMessage(
"test",
tool_calls=[
{
"name": "tool_a",
"args": {},
"id": tool_call_id_a,
},
{
"name": "tool_b",
"args": {},
"id": tool_call_id_b,
},
],
)
]
},
config=_create_config_with_runtime(),
)

# Verify both tools executed
assert len(result["messages"]) == 2
assert result["messages"][0].tool_call_id == tool_call_id_a
assert result["messages"][1].tool_call_id == tool_call_id_b

# Verify each tool has its own checkpoint_ns with its tool_call_id
assert len(captured_configs) == 2

tool_a_config = next((c for c in captured_configs if c["name"] == "tool_a"), None)
tool_b_config = next((c for c in captured_configs if c["name"] == "tool_b"), None)

assert tool_a_config is not None
assert tool_b_config is not None

checkpoint_ns_a = (
tool_a_config["config"].get("configurable", {}).get("checkpoint_ns", "")
)
checkpoint_ns_b = (
tool_b_config["config"].get("configurable", {}).get("checkpoint_ns", "")
)

assert checkpoint_ns_a == tool_call_id_a, (
f"Expected '{tool_call_id_a}', got '{checkpoint_ns_a}'"
)
assert checkpoint_ns_b == tool_call_id_b, (
f"Expected '{tool_call_id_b}', got '{checkpoint_ns_b}'"
)