Skip to content

Commit 7453adb

Browse files
authored
fix(vertexai): allow system messages at any position in ChatAnthropicVertex (#1610)
1 parent 84bad66 commit 7453adb

File tree

2 files changed

+68
-8
lines changed

2 files changed

+68
-8
lines changed

libs/vertexai/langchain_google_vertexai/_anthropic_utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,10 @@ def _format_messages_anthropic(
267267
formatted_messages: list[dict] = []
268268

269269
merged_messages = _merge_messages(messages)
270-
for i, message in enumerate(merged_messages):
270+
for message in merged_messages:
271271
if message.type == "system":
272-
if i != 0:
273-
msg = "System message must be at beginning of message list."
272+
if system_messages is not None:
273+
msg = "Received multiple non-consecutive system messages."
274274
raise ValueError(msg)
275275
fm = _format_message_anthropic(message, project)
276276
if fm:
@@ -458,16 +458,21 @@ def _merge_messages(
458458
curr = curr.model_copy(deep=True)
459459
curr.content = cleaned_content
460460
last = merged[-1] if merged else None
461-
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
462-
if isinstance(last.content, str):
463-
new_content: list = [{"type": "text", "text": last.content}]
461+
if any(
462+
all(isinstance(m, c) for m in (curr, last))
463+
for c in (SystemMessage, HumanMessage)
464+
):
465+
if isinstance(cast("BaseMessage", last).content, str):
466+
new_content: list = [
467+
{"type": "text", "text": cast("BaseMessage", last).content}
468+
]
464469
else:
465-
new_content = last.content
470+
new_content = cast("list", cast("BaseMessage", last).content)
466471
if isinstance(curr.content, str):
467472
new_content.append({"type": "text", "text": curr.content})
468473
else:
469474
new_content.extend(curr.content)
470-
last.content = new_content
475+
merged[-1] = curr.model_copy(update={"content": new_content})
471476
else:
472477
merged.append(curr)
473478
return merged

libs/vertexai/tests/unit_tests/test_anthropic_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,3 +1497,58 @@ def test_format_image(image_url: str, expected_media_type: str) -> None:
14971497
}
14981498

14991499
mock_loader_instance.load_bytes.assert_called_once_with(image_url)
1500+
1501+
1502+
def test_format_messages_anthropic_system_not_first() -> None:
1503+
"""Test that system messages are accepted when not at position 0.
1504+
1505+
Regression test for https://github.com/langchain-ai/langchain-google/issues/1022
1506+
"""
1507+
messages = [
1508+
HumanMessage(content="Hello"),
1509+
AIMessage(content="Hi there!"),
1510+
SystemMessage(content="You are a helpful assistant."),
1511+
HumanMessage(content="What is 2+2?"),
1512+
]
1513+
system_messages, formatted_messages = _format_messages_anthropic(
1514+
messages, project="test-project"
1515+
)
1516+
1517+
assert system_messages == [{"type": "text", "text": "You are a helpful assistant."}]
1518+
# System message should be extracted; remaining messages should be formatted
1519+
assert len(formatted_messages) == 3
1520+
assert formatted_messages[0]["role"] == "user"
1521+
assert formatted_messages[1]["role"] == "assistant"
1522+
assert formatted_messages[2]["role"] == "user"
1523+
1524+
1525+
def test_format_messages_anthropic_consecutive_system_merged() -> None:
1526+
"""Test that consecutive system messages are merged into one."""
1527+
messages = [
1528+
SystemMessage(content="Rule 1."),
1529+
SystemMessage(content="Rule 2."),
1530+
HumanMessage(content="Hello"),
1531+
]
1532+
system_messages, formatted_messages = _format_messages_anthropic(
1533+
messages, project="test-project"
1534+
)
1535+
1536+
# Consecutive system messages should be merged
1537+
assert system_messages == [
1538+
{"type": "text", "text": "Rule 1."},
1539+
{"type": "text", "text": "Rule 2."},
1540+
]
1541+
assert len(formatted_messages) == 1
1542+
assert formatted_messages[0]["role"] == "user"
1543+
1544+
1545+
def test_format_messages_anthropic_multiple_non_consecutive_system_raises() -> None:
1546+
"""Test that multiple non-consecutive system messages raise an error."""
1547+
messages = [
1548+
SystemMessage(content="First system."),
1549+
HumanMessage(content="Hello"),
1550+
SystemMessage(content="Second system."),
1551+
HumanMessage(content="World"),
1552+
]
1553+
with pytest.raises(ValueError, match="multiple non-consecutive system messages"):
1554+
_format_messages_anthropic(messages, project="test-project")

0 commit comments

Comments
 (0)