diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index b9c90464..95e1beea 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -1,4 +1,5 @@ import base64 +import functools import json import logging import re @@ -61,6 +62,33 @@ logger = logging.getLogger(__name__) _BM = TypeVar("_BM", bound=BaseModel) +MIME_TO_FORMAT = { + # Image formats + "image/png": "png", + "image/jpeg": "jpeg", + "image/gif": "gif", + "image/webp": "webp", + # File formats + "application/pdf": "pdf", + "text/csv": "csv", + "application/msword": "doc", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", + "application/vnd.ms-excel": "xls", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx", + "text/html": "html", + "text/plain": "txt", + "text/markdown": "md", + # Video formats + "video/x-matroska": "mkv", + "video/quicktime": "mov", + "video/mp4": "mp4", + "video/webm": "webm", + "video/x-flv": "flv", + "video/mpeg": "mpeg", + "video/x-ms-wmv": "wmv", + "video/3gpp": "three_gp", +} + _DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] @@ -1200,6 +1228,27 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]: raise ValueError(f"Received unsupported stream event:\n\n{event}") +@functools.cache +def _mime_type_to_format(mime_type: str) -> str: + if "/" not in mime_type: + raise ValueError( + f"Invalid MIME type format: {mime_type}. Expected format: 'type/subtype'" + ) + + if mime_type in MIME_TO_FORMAT: + return MIME_TO_FORMAT[mime_type] + + # Fallback to original method of splitting on "/" for simple cases + all_formats = set(MIME_TO_FORMAT.values()) + format_part = mime_type.split("/")[1] + if format_part in all_formats: + return format_part + + raise ValueError( + f"Unsupported MIME type: {mime_type}. Please refer to the Bedrock Converse API documentation for supported formats." + ) + + def _format_data_content_block(block: dict) -> dict: """Format standard data content block to format expected by Converse API.""" if block["type"] == "image": @@ -1209,7 +1258,7 @@ def _format_data_content_block(block: dict) -> dict: raise ValueError(error_message) formatted_block = { "image": { - "format": block["mimeType"].split("/")[1], + "format": _mime_type_to_format(block["mimeType"]), "source": {"bytes": _b64str_to_bytes(block["data"])}, } } @@ -1224,7 +1273,7 @@ def _format_data_content_block(block: dict) -> dict: raise ValueError(error_message) formatted_block = { "document": { - "format": block["mimeType"].split("/")[1], + "format": _mime_type_to_format(block["mimeType"]), "source": {"bytes": _b64str_to_bytes(block["data"])}, } } @@ -1274,7 +1323,7 @@ def _lc_content_to_bedrock( bedrock_content.append( { "image": { - "format": block["source"]["mediaType"].split("/")[1], + "format": _mime_type_to_format(block["source"]["mediaType"]), "source": { "bytes": _b64str_to_bytes(block["source"]["data"]) }, @@ -1295,7 +1344,7 @@ def _lc_content_to_bedrock( bedrock_content.append( { "video": { - "format": block["source"]["mediaType"].split("/")[1], + "format": _mime_type_to_format(block["source"]["mediaType"]), "source": { "bytes": _b64str_to_bytes(block["source"]["data"]) }, @@ -1306,7 +1355,7 @@ def _lc_content_to_bedrock( bedrock_content.append( { "video": { - "format": block["source"]["mediaType"].split("/")[1], + "format": _mime_type_to_format(block["source"]["mediaType"]), "source": {"s3Location": block["source"]["data"]}, } } diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index c66489aa..91118312 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -1107,6 +1107,95 @@ def test__lc_content_to_bedrock_reasoning_content_signature() -> None: assert expected_system == actual_system +def test__lc_content_to_bedrock_mime_types() -> None: + video_data = base64.b64encode(b"video_test_data").decode("utf-8") + image_data = base64.b64encode(b"image_test_data").decode("utf-8") + file_data = base64.b64encode(b"file_test_data").decode("utf-8") + + # Create content with one of each type + content: List[Union[str, Dict[str, Any]]] = [ + { + "type": "video", + "source": { + "type": "base64", + "mediaType": "video/mp4", + "data": video_data, + }, + }, + { + "type": "image", + "source": { + "type": "base64", + "mediaType": "image/jpeg", + "data": image_data, + }, + }, + { + "type": "file", + "sourceType": "base64", + "mimeType": "application/pdf", + "data": file_data, + "name": "test_document.pdf", + }, + ] + + expected_content = [ + { + "video": { + "format": "mp4", + "source": { + "bytes": base64.b64decode(video_data.encode("utf-8")) + }, + } + }, + { + "image": { + "format": "jpeg", + "source": { + "bytes": base64.b64decode(image_data.encode("utf-8")) + }, + } + }, + { + "document": { + "format": "pdf", + "name": "test_document.pdf", + "source": { + "bytes": base64.b64decode(file_data.encode("utf-8")) + }, + } + }, + ] + + bedrock_content = _lc_content_to_bedrock(content) + assert bedrock_content == expected_content + + +def test__lc_content_to_bedrock_mime_types_invalid() -> None: + with pytest.raises(ValueError, match="Invalid MIME type format"): + _lc_content_to_bedrock([ + { + "type": "image", + "source": { + "type": "base64", + "mediaType": "invalidmimetype", + "data": base64.b64encode(b"test_data").decode("utf-8"), + }, + } + ]) + + with pytest.raises(ValueError, match="Unsupported MIME type"): + _lc_content_to_bedrock([ + { + "type": "file", + "sourceType": "base64", + "mimeType": "application/unknown-format", + "data": base64.b64encode(b"test_data").decode("utf-8"), + "name": "test_document.xyz", + } + ]) + + def test__get_provider() -> None: llm = ChatBedrockConverse( model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="us-west-2" @@ -1572,8 +1661,9 @@ def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: # The streaming should be disabled for models with no streaming support assert chat_model.disable_streaming is True + def test_nova_provider_extraction() -> None: """Test that provider is correctly extracted from Nova model ID when not provided.""" model = ChatBedrockConverse(client=mock.MagicMock(), model="us.amazon.nova-pro-v1:0", region_name="us-west-2") - assert model.provider == "amazon" \ No newline at end of file + assert model.provider == "amazon"