Skip to content

Commit 6fb2297

Browse files
committed
fix(core): drop orphaned tool messages when trimming
1 parent 46b87e4 commit 6fb2297

File tree

2 files changed

+117
-6
lines changed

2 files changed

+117
-6
lines changed

libs/core/langchain_core/messages/utils.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,11 +1440,10 @@ def _first_max_tokens(
14401440
# When all messages fit, only apply end_on filtering if needed
14411441
if end_on:
14421442
for _ in range(len(messages)):
1443-
if not _is_message_type(messages[-1], end_on):
1444-
messages.pop()
1445-
else:
1443+
if not messages or _is_message_type(messages[-1], end_on):
14461444
break
1447-
return messages
1445+
messages.pop()
1446+
return _remove_orphaned_tool_messages(messages)
14481447

14491448
# Use binary search to find the maximum number of messages within token limit
14501449
left, right = 0, len(messages)
@@ -1535,7 +1534,7 @@ def _first_max_tokens(
15351534
else:
15361535
break
15371536

1538-
return messages[:idx]
1537+
return _remove_orphaned_tool_messages(messages[:idx])
15391538

15401539

15411540
def _last_max_tokens(
@@ -1594,7 +1593,41 @@ def _last_max_tokens(
15941593
if system_message:
15951594
result = [system_message, *result]
15961595

1597-
return result
1596+
return _remove_orphaned_tool_messages(result)
1597+
1598+
1599+
def _remove_orphaned_tool_messages(
1600+
messages: Sequence[BaseMessage],
1601+
) -> list[BaseMessage]:
1602+
"""Drop tool messages whose corresponding tool calls are absent."""
1603+
if not messages:
1604+
return []
1605+
1606+
valid_tool_call_ids: set[str] = set()
1607+
for message in messages:
1608+
if isinstance(message, AIMessage):
1609+
if message.tool_calls:
1610+
for tool_call in message.tool_calls:
1611+
tool_call_id = tool_call.get("id")
1612+
if tool_call_id:
1613+
valid_tool_call_ids.add(tool_call_id)
1614+
if isinstance(message.content, list):
1615+
for block in message.content:
1616+
if (
1617+
isinstance(block, dict)
1618+
and block.get("type") == "tool_use"
1619+
and block.get("id")
1620+
):
1621+
valid_tool_call_ids.add(block["id"])
1622+
1623+
cleaned_messages: list[BaseMessage] = []
1624+
for message in messages:
1625+
if isinstance(message, ToolMessage) and (
1626+
not valid_tool_call_ids or message.tool_call_id not in valid_tool_call_ids
1627+
):
1628+
continue
1629+
cleaned_messages.append(message)
1630+
return cleaned_messages
15981631

15991632

16001633
_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,84 @@ def test_trim_messages_last_30_include_system_allow_partial_end_on_human() -> No
393393
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
394394

395395

396+
def test_trim_messages_last_removes_orphaned_tool_message() -> None:
397+
messages = [
398+
HumanMessage("What's the weather in Florida?"),
399+
AIMessage(
400+
[
401+
{"type": "text", "text": "Let's check the weather in Florida"},
402+
{
403+
"type": "tool_use",
404+
"id": "abc123",
405+
"name": "get_weather",
406+
"input": {"location": "Florida"},
407+
},
408+
],
409+
tool_calls=[
410+
{
411+
"name": "get_weather",
412+
"args": {"location": "Florida"},
413+
"id": "abc123",
414+
"type": "tool_call",
415+
}
416+
],
417+
),
418+
ToolMessage("It's sunny.", name="get_weather", tool_call_id="abc123"),
419+
HumanMessage("I see"),
420+
AIMessage("Do you want to know anything else?"),
421+
HumanMessage("No, thanks"),
422+
AIMessage("You're welcome! Have a great day!"),
423+
]
424+
425+
trimmed = trim_messages(
426+
messages,
427+
strategy="last",
428+
token_counter=len,
429+
max_tokens=5,
430+
)
431+
432+
expected = [
433+
HumanMessage("I see"),
434+
AIMessage("Do you want to know anything else?"),
435+
HumanMessage("No, thanks"),
436+
AIMessage("You're welcome! Have a great day!"),
437+
]
438+
439+
assert trimmed == expected
440+
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
441+
442+
443+
def test_trim_messages_last_preserves_tool_message_when_call_present() -> None:
444+
messages = [
445+
HumanMessage("Start"),
446+
AIMessage(
447+
"Sure, let me check",
448+
tool_calls=[
449+
{
450+
"name": "search",
451+
"args": {"query": "status"},
452+
"id": "tool-1",
453+
"type": "tool_call",
454+
}
455+
],
456+
),
457+
ToolMessage("All systems operational", tool_call_id="tool-1"),
458+
HumanMessage("Thanks"),
459+
]
460+
461+
trimmed = trim_messages(
462+
messages,
463+
strategy="last",
464+
token_counter=len,
465+
max_tokens=3,
466+
)
467+
468+
expected = messages[1:]
469+
470+
assert trimmed == expected
471+
assert _MESSAGES_TO_TRIM == _MESSAGES_TO_TRIM_COPY
472+
473+
396474
def test_trim_messages_last_40_include_system_allow_partial_start_on_human() -> None:
397475
expected = [
398476
SystemMessage("This is a 4 token text."),

0 commit comments

Comments
 (0)