Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,43 @@ def convert_to_openai_messages(
return oai_messages


def _remove_orphaned_tool_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
"""Remove ToolMessages that don't have a corresponding AIMessage with tool_calls.

When trimming messages, we may accidentally orphan ToolMessages by removing
the AIMessage that made the tool call. This function cleans up such orphans
to maintain valid message history.

Args:
messages: List of messages to clean.

Returns:
List of messages with orphaned ToolMessages removed.
"""
if not messages:
return messages

# Build a set of valid tool_call_ids from AIMessages
valid_tool_call_ids: set[str] = set()
for msg in messages:
if isinstance(msg, AIMessage) and msg.tool_calls:
for tool_call in msg.tool_calls:
if tool_call_id := tool_call.get("id"):
valid_tool_call_ids.add(tool_call_id)

# Filter out ToolMessages with invalid tool_call_ids
cleaned_messages: list[BaseMessage] = []
for msg in messages:
if isinstance(msg, ToolMessage):
if msg.tool_call_id in valid_tool_call_ids:
cleaned_messages.append(msg)
# else: skip orphaned ToolMessage
else:
cleaned_messages.append(msg)

return cleaned_messages


def _first_max_tokens(
messages: Sequence[BaseMessage],
*,
Expand Down Expand Up @@ -1535,7 +1572,9 @@ def _first_max_tokens(
else:
break

return messages[:idx]
trimmed = messages[:idx]
# Remove any orphaned ToolMessages that lost their corresponding AIMessage
return _remove_orphaned_tool_messages(trimmed)


def _last_max_tokens(
Expand Down
267 changes: 267 additions & 0 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1584,3 +1584,270 @@ def test_convert_to_openai_messages_reasoning_content() -> None:
],
}
assert mixed_result == expected_mixed


def test_trim_messages_removes_orphaned_tool_messages() -> None:
"""Test that trim_messages removes orphaned ToolMessages.

When the corresponding AIMessage with tool_calls is trimmed away, the
ToolMessage becomes orphaned and should be removed.

This is the exact scenario from the bug report:
https://github.com/langchain-ai/langchain/issues/xxxxx
"""
messages = [
HumanMessage("What's the weather in Florida?"),
AIMessage(
[
{"type": "text", "text": "Let's check the weather in Florida"},
{
"type": "tool_use",
"id": "abc123",
"name": "get_weather",
"input": {"location": "Florida"},
},
],
tool_calls=[
{
"name": "get_weather",
"args": {"location": "Florida"},
"id": "abc123",
"type": "tool_call",
},
],
),
ToolMessage(
"It's sunny.",
name="get_weather",
tool_call_id="abc123",
),
HumanMessage("I see"),
AIMessage("Do you want to know anything else?"),
HumanMessage("No, thanks"),
AIMessage("You're welcome! Have a great day!"),
]

# Trim to last 5 messages (by count)
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=5,
)

# Should NOT include the orphaned ToolMessage
# The AIMessage with tool_calls was removed, so ToolMessage should be removed too
assert len(trimmed_messages) == 4
assert isinstance(trimmed_messages[0], HumanMessage)
assert trimmed_messages[0].content == "I see"
assert isinstance(trimmed_messages[1], AIMessage)
assert trimmed_messages[1].content == "Do you want to know anything else?"
assert isinstance(trimmed_messages[2], HumanMessage)
assert trimmed_messages[2].content == "No, thanks"
assert isinstance(trimmed_messages[3], AIMessage)
assert trimmed_messages[3].content == "You're welcome! Have a great day!"

# Verify no ToolMessages in result
assert not any(isinstance(msg, ToolMessage) for msg in trimmed_messages)


def test_trim_messages_preserves_valid_tool_calls() -> None:
"""Test that valid tool call sequences are preserved when they fit in the budget."""
messages = [
HumanMessage("What's 2+2?"),
AIMessage(
"Let me calculate that",
tool_calls=[
{
"name": "calculator",
"args": {"expression": "2+2"},
"id": "calc1",
"type": "tool_call",
},
],
),
ToolMessage("4", name="calculator", tool_call_id="calc1"),
AIMessage("The answer is 4"),
]

# Trim to include all messages
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=10,
)

# All messages should be preserved, including valid tool call sequence
assert len(trimmed_messages) == 4
assert isinstance(trimmed_messages[0], HumanMessage)
assert isinstance(trimmed_messages[1], AIMessage)
assert trimmed_messages[1].tool_calls[0]["id"] == "calc1"
assert isinstance(trimmed_messages[2], ToolMessage)
assert trimmed_messages[2].tool_call_id == "calc1"
assert isinstance(trimmed_messages[3], AIMessage)


def test_trim_messages_multiple_tool_calls() -> None:
"""Test handling of multiple tool calls in sequence."""
messages = [
HumanMessage("Get me weather and news"),
AIMessage(
"Fetching both...",
tool_calls=[
{
"name": "get_weather",
"args": {"location": "NYC"},
"id": "tool1",
"type": "tool_call",
},
{
"name": "get_news",
"args": {"topic": "tech"},
"id": "tool2",
"type": "tool_call",
},
],
),
ToolMessage("Sunny", name="get_weather", tool_call_id="tool1"),
ToolMessage("AI news update", name="get_news", tool_call_id="tool2"),
HumanMessage("Thanks"),
AIMessage("You're welcome!"),
]

# Trim to last 2 messages (should remove all tool-related messages)
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=2,
)

assert len(trimmed_messages) == 2
assert isinstance(trimmed_messages[0], HumanMessage)
assert trimmed_messages[0].content == "Thanks"
assert isinstance(trimmed_messages[1], AIMessage)
assert trimmed_messages[1].content == "You're welcome!"
# No ToolMessages should remain
assert not any(isinstance(msg, ToolMessage) for msg in trimmed_messages)


def test_trim_messages_partial_tool_orphaning() -> None:
"""Test when some tool calls are preserved but others are orphaned."""
messages = [
HumanMessage("First question"),
AIMessage(
"Let me check",
tool_calls=[
{
"name": "tool1",
"args": {},
"id": "old_tool",
"type": "tool_call",
},
],
),
ToolMessage("Result 1", name="tool1", tool_call_id="old_tool"),
HumanMessage("Second question"),
AIMessage(
"Checking again",
tool_calls=[
{
"name": "tool2",
"args": {},
"id": "new_tool",
"type": "tool_call",
},
],
),
ToolMessage("Result 2", name="tool2", tool_call_id="new_tool"),
AIMessage("Done"),
]

# Trim to last 4 messages - should keep only the second tool call sequence
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=4,
)

assert len(trimmed_messages) == 4
# Should have the second AIMessage with tool_calls
ai_msgs_with_tools = [
msg for msg in trimmed_messages if isinstance(msg, AIMessage) and msg.tool_calls
]
assert len(ai_msgs_with_tools) == 1
assert ai_msgs_with_tools[0].tool_calls[0]["id"] == "new_tool"

# Should have matching ToolMessage
tool_msgs = [msg for msg in trimmed_messages if isinstance(msg, ToolMessage)]
assert len(tool_msgs) == 1
assert tool_msgs[0].tool_call_id == "new_tool"


def test_trim_messages_strategy_first_removes_orphans() -> None:
"""Test that strategy='first' also removes orphaned ToolMessages."""
messages = [
HumanMessage("Start"),
AIMessage("Response"),
HumanMessage("Question"),
AIMessage(
"Tool time",
tool_calls=[
{
"name": "tool",
"args": {},
"id": "tool123",
"type": "tool_call",
},
],
),
ToolMessage("Result", name="tool", tool_call_id="tool123"),
AIMessage("Final"),
]

# Trim to first 4 messages - cuts off in the middle of tool call sequence
trimmed_messages = trim_messages(
messages,
strategy="first",
token_counter=len,
max_tokens=4,
)

# The AIMessage with tool_calls is included but ToolMessage is cut off
# However, since the ToolMessage isn't in the trimmed result, no orphaning occurs
assert len(trimmed_messages) == 4
assert isinstance(trimmed_messages[3], AIMessage)

# Now test with an orphaned ToolMessage at the beginning
# Trim first 3 messages which includes an orphan
messages_with_orphan = [
HumanMessage("A"),
ToolMessage("Orphaned", name="tool", tool_call_id="missing_id"),
HumanMessage("B"),
AIMessage(
"Using tool",
tool_calls=[
{
"name": "valid_tool",
"args": {},
"id": "valid123",
"type": "tool_call",
},
],
),
ToolMessage("Valid result", name="valid_tool", tool_call_id="valid123"),
]

trimmed = trim_messages(
messages_with_orphan,
strategy="first",
token_counter=len,
max_tokens=3,
)

# Should remove the orphaned ToolMessage but keep valid messages
assert len(trimmed) == 2 # HumanMessage("A") and HumanMessage("B")
tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
assert len(tool_messages) == 0
Loading