From 42332376a059a310dd986e7a3dc3ba4ffde5233e Mon Sep 17 00:00:00 2001 From: Yash V Date: Fri, 3 Oct 2025 18:09:01 -0700 Subject: [PATCH] fix(core): remove orphaned ToolMessages in trim_message --- libs/core/langchain_core/messages/utils.py | 41 ++- .../tests/unit_tests/messages/test_utils.py | 267 ++++++++++++++++++ 2 files changed, 307 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 7e2bd238c1b69..ad02ace45e079 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -1420,6 +1420,43 @@ def convert_to_openai_messages( return oai_messages +def _remove_orphaned_tool_messages(messages: list[BaseMessage]) -> list[BaseMessage]: + """Remove ToolMessages that don't have a corresponding AIMessage with tool_calls. + + When trimming messages, we may accidentally orphan ToolMessages by removing + the AIMessage that made the tool call. This function cleans up such orphans + to maintain valid message history. + + Args: + messages: List of messages to clean. + + Returns: + List of messages with orphaned ToolMessages removed. + """ + if not messages: + return messages + + # Build a set of valid tool_call_ids from AIMessages + valid_tool_call_ids: set[str] = set() + for msg in messages: + if isinstance(msg, AIMessage) and msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call_id := tool_call.get("id"): + valid_tool_call_ids.add(tool_call_id) + + # Filter out ToolMessages with invalid tool_call_ids + cleaned_messages: list[BaseMessage] = [] + for msg in messages: + if isinstance(msg, ToolMessage): + if msg.tool_call_id in valid_tool_call_ids: + cleaned_messages.append(msg) + # else: skip orphaned ToolMessage + else: + cleaned_messages.append(msg) + + return cleaned_messages + + def _first_max_tokens( messages: Sequence[BaseMessage], *, @@ -1535,7 +1572,9 @@ def _first_max_tokens( else: break - return messages[:idx] + trimmed = messages[:idx] + # Remove any orphaned ToolMessages that lost their corresponding AIMessage + return _remove_orphaned_tool_messages(trimmed) def _last_max_tokens( diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 32151132b3d4c..fd9b7f6bc8d7c 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1584,3 +1584,270 @@ def test_convert_to_openai_messages_reasoning_content() -> None: ], } assert mixed_result == expected_mixed + + +def test_trim_messages_removes_orphaned_tool_messages() -> None: + """Test that trim_messages removes orphaned ToolMessages. + + When the corresponding AIMessage with tool_calls is trimmed away, the + ToolMessage becomes orphaned and should be removed. + + This is the exact scenario from the bug report: + https://github.com/langchain-ai/langchain/issues/xxxxx + """ + messages = [ + HumanMessage("What's the weather in Florida?"), + AIMessage( + [ + {"type": "text", "text": "Let's check the weather in Florida"}, + { + "type": "tool_use", + "id": "abc123", + "name": "get_weather", + "input": {"location": "Florida"}, + }, + ], + tool_calls=[ + { + "name": "get_weather", + "args": {"location": "Florida"}, + "id": "abc123", + "type": "tool_call", + }, + ], + ), + ToolMessage( + "It's sunny.", + name="get_weather", + tool_call_id="abc123", + ), + HumanMessage("I see"), + AIMessage("Do you want to know anything else?"), + HumanMessage("No, thanks"), + AIMessage("You're welcome! Have a great day!"), + ] + + # Trim to last 5 messages (by count) + trimmed_messages = trim_messages( + messages, + strategy="last", + token_counter=len, + max_tokens=5, + ) + + # Should NOT include the orphaned ToolMessage + # The AIMessage with tool_calls was removed, so ToolMessage should be removed too + assert len(trimmed_messages) == 4 + assert isinstance(trimmed_messages[0], HumanMessage) + assert trimmed_messages[0].content == "I see" + assert isinstance(trimmed_messages[1], AIMessage) + assert trimmed_messages[1].content == "Do you want to know anything else?" + assert isinstance(trimmed_messages[2], HumanMessage) + assert trimmed_messages[2].content == "No, thanks" + assert isinstance(trimmed_messages[3], AIMessage) + assert trimmed_messages[3].content == "You're welcome! Have a great day!" + + # Verify no ToolMessages in result + assert not any(isinstance(msg, ToolMessage) for msg in trimmed_messages) + + +def test_trim_messages_preserves_valid_tool_calls() -> None: + """Test that valid tool call sequences are preserved when they fit in the budget.""" + messages = [ + HumanMessage("What's 2+2?"), + AIMessage( + "Let me calculate that", + tool_calls=[ + { + "name": "calculator", + "args": {"expression": "2+2"}, + "id": "calc1", + "type": "tool_call", + }, + ], + ), + ToolMessage("4", name="calculator", tool_call_id="calc1"), + AIMessage("The answer is 4"), + ] + + # Trim to include all messages + trimmed_messages = trim_messages( + messages, + strategy="last", + token_counter=len, + max_tokens=10, + ) + + # All messages should be preserved, including valid tool call sequence + assert len(trimmed_messages) == 4 + assert isinstance(trimmed_messages[0], HumanMessage) + assert isinstance(trimmed_messages[1], AIMessage) + assert trimmed_messages[1].tool_calls[0]["id"] == "calc1" + assert isinstance(trimmed_messages[2], ToolMessage) + assert trimmed_messages[2].tool_call_id == "calc1" + assert isinstance(trimmed_messages[3], AIMessage) + + +def test_trim_messages_multiple_tool_calls() -> None: + """Test handling of multiple tool calls in sequence.""" + messages = [ + HumanMessage("Get me weather and news"), + AIMessage( + "Fetching both...", + tool_calls=[ + { + "name": "get_weather", + "args": {"location": "NYC"}, + "id": "tool1", + "type": "tool_call", + }, + { + "name": "get_news", + "args": {"topic": "tech"}, + "id": "tool2", + "type": "tool_call", + }, + ], + ), + ToolMessage("Sunny", name="get_weather", tool_call_id="tool1"), + ToolMessage("AI news update", name="get_news", tool_call_id="tool2"), + HumanMessage("Thanks"), + AIMessage("You're welcome!"), + ] + + # Trim to last 2 messages (should remove all tool-related messages) + trimmed_messages = trim_messages( + messages, + strategy="last", + token_counter=len, + max_tokens=2, + ) + + assert len(trimmed_messages) == 2 + assert isinstance(trimmed_messages[0], HumanMessage) + assert trimmed_messages[0].content == "Thanks" + assert isinstance(trimmed_messages[1], AIMessage) + assert trimmed_messages[1].content == "You're welcome!" + # No ToolMessages should remain + assert not any(isinstance(msg, ToolMessage) for msg in trimmed_messages) + + +def test_trim_messages_partial_tool_orphaning() -> None: + """Test when some tool calls are preserved but others are orphaned.""" + messages = [ + HumanMessage("First question"), + AIMessage( + "Let me check", + tool_calls=[ + { + "name": "tool1", + "args": {}, + "id": "old_tool", + "type": "tool_call", + }, + ], + ), + ToolMessage("Result 1", name="tool1", tool_call_id="old_tool"), + HumanMessage("Second question"), + AIMessage( + "Checking again", + tool_calls=[ + { + "name": "tool2", + "args": {}, + "id": "new_tool", + "type": "tool_call", + }, + ], + ), + ToolMessage("Result 2", name="tool2", tool_call_id="new_tool"), + AIMessage("Done"), + ] + + # Trim to last 4 messages - should keep only the second tool call sequence + trimmed_messages = trim_messages( + messages, + strategy="last", + token_counter=len, + max_tokens=4, + ) + + assert len(trimmed_messages) == 4 + # Should have the second AIMessage with tool_calls + ai_msgs_with_tools = [ + msg for msg in trimmed_messages if isinstance(msg, AIMessage) and msg.tool_calls + ] + assert len(ai_msgs_with_tools) == 1 + assert ai_msgs_with_tools[0].tool_calls[0]["id"] == "new_tool" + + # Should have matching ToolMessage + tool_msgs = [msg for msg in trimmed_messages if isinstance(msg, ToolMessage)] + assert len(tool_msgs) == 1 + assert tool_msgs[0].tool_call_id == "new_tool" + + +def test_trim_messages_strategy_first_removes_orphans() -> None: + """Test that strategy='first' also removes orphaned ToolMessages.""" + messages = [ + HumanMessage("Start"), + AIMessage("Response"), + HumanMessage("Question"), + AIMessage( + "Tool time", + tool_calls=[ + { + "name": "tool", + "args": {}, + "id": "tool123", + "type": "tool_call", + }, + ], + ), + ToolMessage("Result", name="tool", tool_call_id="tool123"), + AIMessage("Final"), + ] + + # Trim to first 4 messages - cuts off in the middle of tool call sequence + trimmed_messages = trim_messages( + messages, + strategy="first", + token_counter=len, + max_tokens=4, + ) + + # The AIMessage with tool_calls is included but ToolMessage is cut off + # However, since the ToolMessage isn't in the trimmed result, no orphaning occurs + assert len(trimmed_messages) == 4 + assert isinstance(trimmed_messages[3], AIMessage) + + # Now test with an orphaned ToolMessage at the beginning + # Trim first 3 messages which includes an orphan + messages_with_orphan = [ + HumanMessage("A"), + ToolMessage("Orphaned", name="tool", tool_call_id="missing_id"), + HumanMessage("B"), + AIMessage( + "Using tool", + tool_calls=[ + { + "name": "valid_tool", + "args": {}, + "id": "valid123", + "type": "tool_call", + }, + ], + ), + ToolMessage("Valid result", name="valid_tool", tool_call_id="valid123"), + ] + + trimmed = trim_messages( + messages_with_orphan, + strategy="first", + token_counter=len, + max_tokens=3, + ) + + # Should remove the orphaned ToolMessage but keep valid messages + assert len(trimmed) == 2 # HumanMessage("A") and HumanMessage("B") + tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)] + assert len(tool_messages) == 0