Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,48 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]:
raise ValueError(f"Received unsupported stream event:\n\n{event}")


def _mime_type_to_format(mime_type: str) -> str:
mime_to_format = {
# Image formats
"image/png": "png",
"image/jpeg": "jpeg",
"image/gif": "gif",
"image/webp": "webp",
# File formats
"application/pdf": "pdf",
"text/csv": "csv",
"application/msword": "doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
"application/vnd.ms-excel": "xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
"text/html": "html",
"text/plain": "txt",
"text/markdown": "md",
# Video formats
"video/x-matroska": "mkv",
"video/quicktime": "mov",
"video/mp4": "mp4",
"video/webm": "webm",
"video/x-flv": "flv",
"video/mpeg": "mpeg",
"video/x-ms-wmv": "wmv",
"video/3gpp": "three_gp",
}

if mime_type in mime_to_format:
return mime_to_format[mime_type]

# Fallback to original method of splitting on "/" for simple cases
all_formats = set(mime_to_format.values())
format_part = mime_type.split("/")[1]
if format_part in all_formats:
return format_part

raise ValueError(
f"Unsupported MIME type: {mime_type}. Please refer to the Bedrock Converse API documentation for supported formats."
)


def _format_data_content_block(block: dict) -> dict:
"""Format standard data content block to format expected by Converse API."""
if block["type"] == "image":
Expand All @@ -1209,7 +1251,7 @@ def _format_data_content_block(block: dict) -> dict:
raise ValueError(error_message)
formatted_block = {
"image": {
"format": block["mimeType"].split("/")[1],
"format": _mime_type_to_format(block["mimeType"]),
"source": {"bytes": _b64str_to_bytes(block["data"])},
}
}
Expand All @@ -1224,7 +1266,7 @@ def _format_data_content_block(block: dict) -> dict:
raise ValueError(error_message)
formatted_block = {
"document": {
"format": block["mimeType"].split("/")[1],
"format": _mime_type_to_format(block["mimeType"]),
"source": {"bytes": _b64str_to_bytes(block["data"])},
}
}
Expand Down Expand Up @@ -1274,7 +1316,7 @@ def _lc_content_to_bedrock(
bedrock_content.append(
{
"image": {
"format": block["source"]["mediaType"].split("/")[1],
"format": _mime_type_to_format(block["source"]["mediaType"]),
"source": {
"bytes": _b64str_to_bytes(block["source"]["data"])
},
Expand All @@ -1295,7 +1337,7 @@ def _lc_content_to_bedrock(
bedrock_content.append(
{
"video": {
"format": block["source"]["mediaType"].split("/")[1],
"format": _mime_type_to_format(block["source"]["mediaType"]),
"source": {
"bytes": _b64str_to_bytes(block["source"]["data"])
},
Expand All @@ -1306,7 +1348,7 @@ def _lc_content_to_bedrock(
bedrock_content.append(
{
"video": {
"format": block["source"]["mediaType"].split("/")[1],
"format": _mime_type_to_format(block["source"]["mediaType"]),
"source": {"s3Location": block["source"]["data"]},
}
}
Expand Down
67 changes: 66 additions & 1 deletion libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,70 @@ def test__lc_content_to_bedrock_reasoning_content_signature() -> None:
assert expected_system == actual_system


def test__lc_content_to_bedrock_mime_types() -> None:
video_data = base64.b64encode(b"video_test_data").decode("utf-8")
image_data = base64.b64encode(b"image_test_data").decode("utf-8")
file_data = base64.b64encode(b"file_test_data").decode("utf-8")

# Create content with one of each type
content: List[Union[str, Dict[str, Any]]] = [
{
"type": "video",
"source": {
"type": "base64",
"mediaType": "video/mp4",
"data": video_data,
},
},
{
"type": "image",
"source": {
"type": "base64",
"mediaType": "image/jpeg",
"data": image_data,
},
},
{
"type": "file",
"sourceType": "base64",
"mimeType": "application/pdf",
"data": file_data,
"name": "test_document.pdf",
},
]

expected_content = [
{
"video": {
"format": "mp4",
"source": {
"bytes": base64.b64decode(video_data.encode("utf-8"))
},
}
},
{
"image": {
"format": "jpeg",
"source": {
"bytes": base64.b64decode(image_data.encode("utf-8"))
},
}
},
{
"document": {
"format": "pdf",
"name": "test_document.pdf",
"source": {
"bytes": base64.b64decode(file_data.encode("utf-8"))
},
}
},
]

bedrock_content = _lc_content_to_bedrock(content)
assert bedrock_content == expected_content


def test__get_provider() -> None:
llm = ChatBedrockConverse(
model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="us-west-2"
Expand Down Expand Up @@ -1572,8 +1636,9 @@ def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:

# The streaming should be disabled for models with no streaming support
assert chat_model.disable_streaming is True


def test_nova_provider_extraction() -> None:
"""Test that provider is correctly extracted from Nova model ID when not provided."""
model = ChatBedrockConverse(client=mock.MagicMock(), model="us.amazon.nova-pro-v1:0", region_name="us-west-2")
assert model.provider == "amazon"
assert model.provider == "amazon"