Skip to content

Commit 046efe5

Browse files
authored
aws[minor]: Add ChatModel that uses Bedrock.converse API
2 parents c873bb2 + a7a2d09 commit 046efe5

File tree

11 files changed

+1115
-99
lines changed

11 files changed

+1115
-99
lines changed

libs/aws/langchain_aws/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from langchain_aws.chat_models import BedrockChat, ChatBedrock
1+
from langchain_aws.chat_models import BedrockChat, ChatBedrock, ChatBedrockConverse
22
from langchain_aws.embeddings import BedrockEmbeddings
33
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
44
from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint
@@ -13,6 +13,7 @@
1313
"BedrockLLM",
1414
"BedrockChat",
1515
"ChatBedrock",
16+
"ChatBedrockConverse",
1617
"SagemakerEndpoint",
1718
"AmazonKendraRetriever",
1819
"AmazonKnowledgeBasesRetriever",
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock
2+
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
23

3-
__all__ = ["BedrockChat", "ChatBedrock"]
4+
__all__ = ["BedrockChat", "ChatBedrock", "ChatBedrockConverse"]

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
3838
from langchain_core.tools import BaseTool
3939

40+
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
4041
from langchain_aws.function_calling import (
4142
ToolsOutputParser,
4243
_lc_tool_calls_to_anthropic_tool_use_blocks,
@@ -387,6 +388,9 @@ class ChatBedrock(BaseChatModel, BedrockBase):
387388
"""A chat model that uses the Bedrock API."""
388389

389390
system_prompt_with_tools: str = ""
391+
beta_use_converse_api: bool = False
392+
"""Use the new Bedrock ``converse`` API which provides a standardized interface to
393+
all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more."""
390394

391395
@property
392396
def _llm_type(self) -> str:
@@ -424,6 +428,11 @@ def _stream(
424428
run_manager: Optional[CallbackManagerForLLMRun] = None,
425429
**kwargs: Any,
426430
) -> Iterator[ChatGenerationChunk]:
431+
if self.beta_use_converse_api:
432+
yield from self._as_converse._stream(
433+
messages, stop=stop, run_manager=run_manager, **kwargs
434+
)
435+
return
427436
provider = self._get_provider()
428437
prompt, system, formatted_messages = None, None, None
429438

@@ -490,6 +499,10 @@ def _generate(
490499
run_manager: Optional[CallbackManagerForLLMRun] = None,
491500
**kwargs: Any,
492501
) -> ChatResult:
502+
if self.beta_use_converse_api:
503+
return self._as_converse._generate(
504+
messages, stop=stop, run_manager=run_manager, **kwargs
505+
)
493506
completion = ""
494507
llm_output: Dict[str, Any] = {}
495508
tool_calls: List[Dict[str, Any]] = []
@@ -608,6 +621,12 @@ def bind_tools(
608621
**kwargs: Any additional parameters to pass to the
609622
:class:`~langchain.runnable.Runnable` constructor.
610623
"""
624+
if self.beta_use_converse_api:
625+
if isinstance(tool_choice, bool):
626+
tool_choice = "any" if tool_choice else None
627+
return self._as_converse.bind_tools(
628+
tools, tool_choice=tool_choice, **kwargs
629+
)
611630
if self._get_provider() == "anthropic":
612631
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
613632

@@ -745,6 +764,10 @@ class AnswerWithJustification(BaseModel):
745764
# }
746765
747766
""" # noqa: E501
767+
if self.beta_use_converse_api:
768+
return self._as_converse.with_structured_output(
769+
schema, include_raw=include_raw, **kwargs
770+
)
748771
if "claude-3" not in self._get_model():
749772
ValueError(
750773
f"Structured output is not supported for model {self._get_model()}"
@@ -769,6 +792,23 @@ class AnswerWithJustification(BaseModel):
769792
else:
770793
return llm | output_parser
771794

795+
@property
796+
def _as_converse(self) -> ChatBedrockConverse:
797+
kwargs = {
798+
k: v
799+
for k, v in (self.model_kwargs or {}).items()
800+
if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
801+
}
802+
return ChatBedrockConverse(
803+
model=self.model_id,
804+
region_name=self.region_name,
805+
credentials_profile_name=self.credentials_profile_name,
806+
config=self.config,
807+
provider=self.provider or "",
808+
base_url=self.endpoint_url,
809+
**kwargs,
810+
)
811+
772812

773813
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
774814
class BedrockChat(ChatBedrock):

0 commit comments

Comments
 (0)