Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,9 @@ class Joke(BaseModel):
additionalModelResponseFieldPaths.
"""

supports_tool_choice_values: Optional[
Sequence[Literal["auto", "any", "tool"]]
] = None
supports_tool_choice_values: Optional[Sequence[Literal["auto", "any", "tool"]]] = (
None
)
"""Which types of tool_choice values the model supports.

Inferred if not specified. Inferred as ('auto', 'any', 'tool') if a 'claude-3'
Expand Down Expand Up @@ -514,6 +514,70 @@ def create_cache_point(cls, cache_type: str = "default") -> Dict[str, Any]:
"""
return {"cachePoint": {"type": cache_type}}

@classmethod
def create_document(
cls,
name: str,
source: dict[str, Any],
format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"],
context: Optional[str] = None,
enable_citations: Optional[bool] = False,

) -> Dict[str, Any]:
"""Create a document configuration for Bedrock.
Args:
name: The name of the document.
source: The source of the document.
format: The format of the document, or its extension.
context: Info for the model to understand the document for citations.
enable_citations: Whether to enable the Citations API for the document.
Returns:
Dictionary containing a properly formatted to add to message content."""
if re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name):
raise ValueError(
"Name must be only alphanumeric characters,"
" whitespace characters (no more than one in a row),"
" hyphens, parentheses, or square brackets."
)

valid_source_types = ["bytes", "content", "s3Location", "text"]
if (
len(source.keys()) > 1
or list(source.keys())[0] not in valid_source_types
):
raise ValueError(
f"The key for source can only be one of the following: {valid_source_types}"
)

if source.get("bytes") and not isinstance(source.get("bytes"), bytes):
raise ValueError(f"Document source with type bytes must be bytes type.")

if source.get("text") and not isinstance(source.get("text"), str):
raise ValueError("Document source with type text must be str type.")

if source.get("s3Location") and not isinstance(
source.get("s3Location").get("uri"), str
):
raise ValueError(
"Document source with type s3Location"
" must have a dictionary with a valid s3 uri as a dict."
)

if source.get("content") and not isinstance(source.get("content", list)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing content source currently fails because isinstance is missing the second argument for type here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops good catch thanks!

raise ValueError(
"Document source with type content must have a list of document content blocks."
)

document = {"name": name, "source": source, "format": format}

if context:
document["context"] = context

if enable_citations:
document["citations"] = {"enabled": True}

return {"document": document}

@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict[str, Any]) -> Any:
Expand Down Expand Up @@ -657,19 +721,6 @@ def set_disable_streaming(cls, values: Dict) -> Any:
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""

# Skip creating new client if passed in constructor
if self.client is None:
self.client = create_aws_client(
Comment on lines 660 to 662
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this was removed unintentionally in the 40241bd merge, let's put it back to fix the tests

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added it back in but I did change my setup so hoping nothing else broke.

region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)

# Create bedrock client for control plane API call
if self.bedrock_client is None:
bedrock_client_cfg = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,3 +638,51 @@ def test_bedrock_pdf_inputs() -> None:
]
)
_ = model.invoke([message])


def test_bedrock_document_usage() -> None:
model = ChatBedrockConverse(
model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2"
)

# Test bytes source type
url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
pdf_bytes = httpx.get(url).content
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
ChatBedrockConverse.create_document(
"PDFDoc", source={"bytes": pdf_bytes}, format="pdf"
),
]
)

_ = model.invoke([message])

# Test text source type
text = "I am a text document."
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
ChatBedrockConverse.create_document(
"TextDoc", source={"text": text}, format="txt"
),
]
)
_ = model.invoke([message])

# Test content source type
split_text = [
{"text": "I am the first part of a document."},
{"text": "I am the second part."},
{"text": "I am not sure how I got here."},
]
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
ChatBedrockConverse.create_document(
"TextDoc", source={"content": split_text}, format="txt"
),
]
)
_ = model.invoke([message])
77 changes: 44 additions & 33 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,7 @@ def test__snake_to_camel_keys() -> None:
assert _snake_to_camel_keys(_SNAKE_DICT) == _CAMEL_DICT


def test__format_openai_image_url() -> None:
...
def test__format_openai_image_url() -> None: ...


def test_standard_tracing_params() -> None:
Expand Down Expand Up @@ -1266,27 +1265,27 @@ def test__lc_content_to_bedrock_mime_types_invalid() -> None:

def test__lc_content_to_bedrock_empty_content() -> None:
content: List[Union[str, Dict[str, Any]]] = []

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) > 0
assert bedrock_content[0]["text"] == "."


def test__lc_content_to_bedrock_whitespace_only_content() -> None:
content = " \n \t "

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) > 0
assert bedrock_content[0]["text"] == "."


def test__lc_content_to_bedrock_empty_string_content() -> None:
content = ""

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) > 0
assert bedrock_content[0]["text"] == "."

Expand All @@ -1295,33 +1294,29 @@ def test__lc_content_to_bedrock_mixed_empty_content() -> None:
content: List[Union[str, Dict[str, Any]]] = [
{"type": "text", "text": ""},
{"type": "text", "text": " "},
{"type": "text", "text": ""}
{"type": "text", "text": ""},
]

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) > 0
assert bedrock_content[0]["text"] == "."


def test__lc_content_to_bedrock_empty_text_block() -> None:
content: List[Union[str, Dict[str, Any]]] = [
{"type": "text", "text": ""}
]

content: List[Union[str, Dict[str, Any]]] = [{"type": "text", "text": ""}]

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) > 0
assert bedrock_content[0]["text"] == "."


def test__lc_content_to_bedrock_whitespace_text_block() -> None:
content: List[Union[str, Dict[str, Any]]] = [
{"type": "text", "text": " \n "}
]

content: List[Union[str, Dict[str, Any]]] = [{"type": "text", "text": " \n "}]

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) > 0
assert bedrock_content[0]["text"] == "."

Expand All @@ -1330,9 +1325,9 @@ def test__lc_content_to_bedrock_mixed_valid_and_empty_content() -> None:
content: List[Union[str, Dict[str, Any]]] = [
{"type": "text", "text": "Valid text"},
{"type": "text", "text": ""},
{"type": "text", "text": " "}
{"type": "text", "text": " "},
]

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) == 3
Expand All @@ -1350,21 +1345,21 @@ def test__lc_content_to_bedrock_mixed_types_with_empty_content() -> None:
"input": {"arg1": "val1"},
"name": "tool1",
},
{"type": "text", "text": " "}
{"type": "text", "text": " "},
]

expected = [
{'text': 'Valid text'},
{"text": "Valid text"},
{
'toolUse': {
'toolUseId': 'tool_call1',
'input': {'arg1': 'val1'},
'name': 'tool1'
"toolUse": {
"toolUseId": "tool_call1",
"input": {"arg1": "val1"},
"name": "tool1",
}
},
{'text': '.'}
{"text": "."},
]

bedrock_content = _lc_content_to_bedrock(content)

assert len(bedrock_content) == 3
Expand Down Expand Up @@ -1497,6 +1492,22 @@ def test_create_cache_point() -> None:
assert cache_point["cachePoint"]["type"] == "default"


def test_create_document() -> None:
"""Test creating a document."""
document = ChatBedrockConverse.create_document(
name="MyDoc", source={"text": "Cite me"}, format="txt", enable_citations=True
)
expected_doc = {
"document": {
"name": "MyDoc",
"source": {"text": "Cite me"},
"format": "txt",
"citations": {"enabled": True},
}
}
assert document == expected_doc


def test_anthropic_tool_with_cache_point() -> None:
"""Test convert_to_anthropic_tool with cache point"""
# Test with cache point
Expand Down Expand Up @@ -1573,9 +1584,9 @@ def test_model_kwargs() -> None:
assert llm.temperature is None


def _create_mock_llm_guard_last_turn_only() -> (
Tuple[ChatBedrockConverse, mock.MagicMock]
):
def _create_mock_llm_guard_last_turn_only() -> Tuple[
ChatBedrockConverse, mock.MagicMock
]:
"""Utility to create an LLM with guard_last_turn_only=True and a mocked client."""
mocked_client = mock.MagicMock()
llm = ChatBedrockConverse(
Expand Down
Loading