Skip to content

Commit 54aa053

Browse files
committed
fmt
1 parent 3d315b8 commit 54aa053

File tree

2 files changed

+44
-6
lines changed

2 files changed

+44
-6
lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ def _messages_to_bedrock(
493493
if isinstance(msg, HumanMessage):
494494
bedrock_messages.append({"role": "user", "content": content})
495495
elif isinstance(msg, AIMessage):
496+
content = _upsert_tool_calls_to_bedrock_content(content, msg.tool_calls)
496497
bedrock_messages.append({"role": "assistant", "content": content})
497498
elif isinstance(msg, SystemMessage):
498499
if isinstance(msg.content, str):
@@ -514,7 +515,7 @@ def _messages_to_bedrock(
514515

515516
# TODO: Add status once we have ToolMessage.status support.
516517
curr["content"].append(
517-
{"toolResult": {"content": content, "toolUseID": msg.tool_call_id}}
518+
{"toolResult": {"content": content, "toolUseId": msg.tool_call_id}}
518519
)
519520
bedrock_messages.append(curr)
520521
else:
@@ -529,7 +530,7 @@ def _parse_response(response: Dict[str, Any]) -> AIMessage:
529530
tool_calls = _extract_tool_calls(anthropic_content)
530531
usage = UsageMetadata(_camel_to_snake_keys(response.pop("usage"))) # type: ignore[misc]
531532
return AIMessage(
532-
content=anthropic_content, # type: ignore[arg-type]
533+
content=_str_if_single_text_block(anthropic_content), # type: ignore[arg-type]
533534
usage_metadata=usage,
534535
response_metadata=response,
535536
tool_calls=tool_calls,
@@ -602,7 +603,7 @@ def _anthropic_to_bedrock(
602603
content: Union[str, List[Union[str, Dict[str, Any]]]],
603604
) -> List[Dict[str, Any]]:
604605
if isinstance(content, str):
605-
return [{"text": content}]
606+
content = [{"text": content}]
606607
bedrock_content: List[Dict[str, Any]] = []
607608
for block in _snake_to_camel_keys(content):
608609
if isinstance(block, str):
@@ -641,7 +642,7 @@ def _anthropic_to_bedrock(
641642
bedrock_content.append(
642643
{
643644
"toolResult": {
644-
"toolUseID": block["toolUseId"],
645+
"toolUseId": block["toolUseId"],
645646
"content": _anthropic_to_bedrock(content),
646647
}
647648
}
@@ -651,7 +652,8 @@ def _anthropic_to_bedrock(
651652
bedrock_content.append({"json": block["json"]})
652653
else:
653654
raise ValueError(f"Unsupported content block type:\n{block}")
654-
return bedrock_content
655+
# drop empty text blocks
656+
return [block for block in bedrock_content if block.get("text", True)]
655657

656658

657659
def _bedrock_to_anthropic(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -778,3 +780,39 @@ def _b64str_to_bytes(base64_str: str) -> bytes:
778780

779781
def _bytes_to_b64_str(bytes_: bytes) -> str:
780782
return base64.b64encode(bytes_).decode("utf-8")
783+
784+
785+
def _str_if_single_text_block(
786+
anthropic_content: List[Dict[str, Any]],
787+
) -> Union[str, List[Dict[str, Any]]]:
788+
if len(anthropic_content) == 1 and anthropic_content[0]["type"] == "text":
789+
return anthropic_content[0]["text"]
790+
return anthropic_content
791+
792+
793+
def _upsert_tool_calls_to_bedrock_content(
794+
content: List[Dict[str, Any]], tool_calls: List[ToolCall]
795+
) -> List[Dict[str, Any]]:
796+
existing_tc_blocks = [block for block in content if "toolUse" in block]
797+
for tool_call in tool_calls:
798+
if tool_call["id"] in [
799+
block["toolUse"]["toolUseId"] for block in existing_tc_blocks
800+
]:
801+
tc_block = next(
802+
block
803+
for block in existing_tc_blocks
804+
if block["toolUse"]["toolUseId"] == tool_call["id"]
805+
)
806+
tc_block["toolUse"]["input"] = tool_call["args"]
807+
tc_block["toolUse"]["name"] = tool_call["name"]
808+
else:
809+
content.append(
810+
{
811+
"toolUse": {
812+
"toolUseId": tool_call["id"],
813+
"input": tool_call["args"],
814+
"name": tool_call["name"],
815+
}
816+
}
817+
)
818+
return content

libs/aws/tests/integration_tests/chat_models/test_bedrock_converse_standard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ def chat_model_class(self) -> Type[BaseChatModel]:
1818
def chat_model_params(self) -> dict:
1919
return {
2020
"model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
21-
}
21+
}

0 commit comments

Comments
 (0)