From 6fb2297f6d3df56de8f8c398a8ed29205e09d033 Mon Sep 17 00:00:00 2001
From: Navpreet Singh
Date: Sat, 4 Oct 2025 18:41:46 +0100
Subject: [PATCH] fix(core): drop orphaned tool messages when trimming

---
 libs/core/langchain_core/messages/utils.py   | 45 +++++++++--
 .../tests/unit_tests/messages/test_utils.py  | 78 +++++++++++++++++++
 2 files changed, 117 insertions(+), 6 deletions(-)

diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py
index 7e2bd238c1b69..6814ff4077374 100644
--- a/libs/core/langchain_core/messages/utils.py
+++ b/libs/core/langchain_core/messages/utils.py
@@ -1440,11 +1440,10 @@ def _first_max_tokens(
         # When all messages fit, only apply end_on filtering if needed
         if end_on:
             for _ in range(len(messages)):
-                if not _is_message_type(messages[-1], end_on):
-                    messages.pop()
-                else:
+                if not messages or _is_message_type(messages[-1], end_on):
                     break
-        return messages
+                messages.pop()
+        return _remove_orphaned_tool_messages(messages)
 
     # Use binary search to find the maximum number of messages within token limit
     left, right = 0, len(messages)
@@ -1535,7 +1534,7 @@ def _first_max_tokens(
             else:
                 break
 
-    return messages[:idx]
+    return _remove_orphaned_tool_messages(messages[:idx])
 
 
 def _last_max_tokens(
@@ -1594,7 +1593,41 @@ def _last_max_tokens(
 
     if system_message:
         result = [system_message, *result]
-    return result
+    return _remove_orphaned_tool_messages(result)
+
+
+def _remove_orphaned_tool_messages(
+    messages: Sequence[BaseMessage],
+) -> list[BaseMessage]:
+    """Drop tool messages whose corresponding tool calls are absent."""
+    if not messages:
+        return []
+
+    valid_tool_call_ids: set[str] = set()
+    for message in messages:
+        if isinstance(message, AIMessage):
+            if message.tool_calls:
+                for tool_call in message.tool_calls:
+                    tool_call_id = tool_call.get("id")
+                    if tool_call_id:
+                        valid_tool_call_ids.add(tool_call_id)
+            if isinstance(message.content, list):
+                for block in message.content:
+                    if (
+                        isinstance(block, dict)
+                        and block.get("type") == "tool_use"
+                        and block.get("id")
+                    ):
+                        valid_tool_call_ids.add(block["id"])
+
+    cleaned_messages: list[BaseMessage] = []
+    for message in messages:
+        if isinstance(message, ToolMessage) and (
+            not valid_tool_call_ids or message.tool_call_id not in valid_tool_call_ids
+        ):
+            continue
+        cleaned_messages.append(message)
+    return cleaned_messages
 
 
 _MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py
index 32151132b3d4c..93e8e0fc9c9c1 100644
--- a/libs/core/tests/unit_tests/messages/test_utils.py
+++ b/libs/core/tests/unit_tests/messages/test_utils.py
@@ -393,6 +393,84 @@ def test_trim_messages_last_30_include_system_allow_partial_end_on_human() -> No
     assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
 
 
+def test_trim_messages_last_removes_orphaned_tool_message() -> None:
+    messages = [
+        HumanMessage("What's the weather in Florida?"),
+        AIMessage(
+            [
+                {"type": "text", "text": "Let's check the weather in Florida"},
+                {
+                    "type": "tool_use",
+                    "id": "abc123",
+                    "name": "get_weather",
+                    "input": {"location": "Florida"},
+                },
+            ],
+            tool_calls=[
+                {
+                    "name": "get_weather",
+                    "args": {"location": "Florida"},
+                    "id": "abc123",
+                    "type": "tool_call",
+                }
+            ],
+        ),
+        ToolMessage("It's sunny.", name="get_weather", tool_call_id="abc123"),
+        HumanMessage("I see"),
+        AIMessage("Do you want to know anything else?"),
+        HumanMessage("No, thanks"),
+        AIMessage("You're welcome! Have a great day!"),
+    ]
+
+    trimmed = trim_messages(
+        messages,
+        strategy="last",
+        token_counter=len,
+        max_tokens=5,
+    )
+
+    expected = [
+        HumanMessage("I see"),
+        AIMessage("Do you want to know anything else?"),
+        HumanMessage("No, thanks"),
+        AIMessage("You're welcome! Have a great day!"),
+    ]
+
+    assert trimmed == expected
+    assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
+
+
+def test_trim_messages_last_preserves_tool_message_when_call_present() -> None:
+    messages = [
+        HumanMessage("Start"),
+        AIMessage(
+            "Sure, let me check",
+            tool_calls=[
+                {
+                    "name": "search",
+                    "args": {"query": "status"},
+                    "id": "tool-1",
+                    "type": "tool_call",
+                }
+            ],
+        ),
+        ToolMessage("All systems operational", tool_call_id="tool-1"),
+        HumanMessage("Thanks"),
+    ]
+
+    trimmed = trim_messages(
+        messages,
+        strategy="last",
+        token_counter=len,
+        max_tokens=3,
+    )
+
+    expected = messages[1:]
+
+    assert trimmed == expected
+    assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
+
+
 def test_trim_messages_last_40_include_system_allow_partial_start_on_human() -> None:
     expected = [
         SystemMessage("This is a 4 token text."),
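
Usage sketch (not part of the patch): a minimal example of the behavior this change targets, assuming langchain_core with the patch applied. The conversation, the get_weather tool name, and the abc123 call id below are invented for illustration; token_counter=len counts one token per message, exactly as the new unit tests do.

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, trim_messages

# A short history in which the tool call and its result sit in the middle.
messages = [
    HumanMessage("What's the weather in Florida?"),
    AIMessage(
        "Checking the weather now.",
        tool_calls=[
            {
                "name": "get_weather",
                "args": {"location": "Florida"},
                "id": "abc123",
                "type": "tool_call",
            }
        ],
    ),
    ToolMessage("It's sunny.", tool_call_id="abc123"),
    HumanMessage("Thanks!"),
    AIMessage("You're welcome!"),
]

# Keep only the last three messages, counting one "token" per message.
# The AIMessage that issued tool call "abc123" falls outside the window,
# so the ToolMessage answering it would be left orphaned.
trimmed = trim_messages(
    messages,
    strategy="last",
    token_counter=len,
    max_tokens=3,
)

# With the patch applied, the orphaned ToolMessage is dropped as well, leaving
# [HumanMessage("Thanks!"), AIMessage("You're welcome!")].
for message in trimmed:
    print(type(message).__name__, message.content)

Dropping the orphaned ToolMessage rather than keeping it matters because chat model APIs generally reject a tool result whose originating tool call is no longer in the history; the patch applies the same cleanup to the "first" strategy and also recognizes tool_use ids carried in AIMessage content blocks.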