Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,10 @@ def _merge_messages(
]
)
last = merged[-1] if merged else None
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
if any(
all(isinstance(m, c) for m in (curr, last))
for c in (SystemMessage, HumanMessage)
):
if isinstance(last.content, str):
new_content: List = [{"type": "text", "text": last.content}]
else:
Expand All @@ -387,9 +390,11 @@ def _format_anthropic_messages(
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
if isinstance(message.content, str):
if system is not None:
raise ValueError(
"Received multiple non-consecutive system messages."
)
elif isinstance(message.content, str):
system = message.content
elif isinstance(message.content, list):
system_blocks = []
Expand Down
4 changes: 2 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,8 +1037,8 @@ def _messages_to_bedrock(
"""Handle Bedrock converse and Anthropic style content blocks"""
bedrock_messages: List[Dict[str, Any]] = []
bedrock_system: List[Dict[str, Any]] = []
# Merge system, human, ai message runs because Anthropic expects (at most) 1
# system message then alternating human/ai messages.
# Merge system, human, ai message runs because Anthropic expects
# (optional) system messages first, then alternating human/ai messages.
messages = merge_message_runs(messages)
for msg in messages:
content = _lc_content_to_bedrock(msg.content)
Expand Down
36 changes: 35 additions & 1 deletion libs/aws/tests/unit_tests/chat_models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
def test__merge_messages() -> None:
messages = [
SystemMessage("foo"), # type: ignore[misc]
SystemMessage("barfoo"), # type: ignore[misc]
HumanMessage("bar"), # type: ignore[misc]
AIMessage( # type: ignore[misc]
[
Expand All @@ -52,7 +53,12 @@ def test__merge_messages() -> None:
HumanMessage("next thing"), # type: ignore[misc]
]
expected = [
SystemMessage("foo"), # type: ignore[misc]
SystemMessage(
[
{'type': 'text', 'text': 'foo'},
{'type': 'text', 'text': 'barfoo'}
]
), # type: ignore[misc]
HumanMessage("bar"), # type: ignore[misc]
AIMessage( # type: ignore[misc]
[
Expand Down Expand Up @@ -345,6 +351,34 @@ def test__format_anthropic_messages_system_message_list_content() -> None:
actual = _format_anthropic_messages(messages)
assert expected == actual

def test__format_anthropic_multiple_system_messages() -> None:
"""Test that multiple system messages can be passed, and that none of them are required to be at position 0."""
system1 = SystemMessage("foo") # type: ignore[misc]
system2 = SystemMessage("bar") # type: ignore[misc]
human = HumanMessage("Hello!")
messages = [human, system1, system2]
expected_system = [
{'text': 'foo', 'type': 'text'},
{'text': 'bar', 'type': 'text'}
]
expected_messages = [
{"role": "user", "content": "Hello!"}
]

actual_system, actual_messages = _format_anthropic_messages(messages)
assert expected_system == actual_system
assert expected_messages == actual_messages

def test__format_anthropic_nonconsecutive_system_messages() -> None:
"""Test that we fail when non-consecutive system messages are passed."""
system1 = SystemMessage("foo") # type: ignore[misc]
system2 = SystemMessage("bar") # type: ignore[misc]
human = HumanMessage("Hello!")
messages = [system1, human, system2]

with pytest.raises(ValueError, match="Received multiple non-consecutive system messages."):
_format_anthropic_messages(messages)


@pytest.fixture()
def pydantic() -> Type[BaseModel]:
Expand Down