Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain_aws.chat_models import BedrockChat, ChatBedrock
from langchain_aws.chat_models import BedrockChat, ChatBedrock, ChatBedrockConverse
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint
Expand All @@ -13,6 +13,7 @@
"BedrockLLM",
"BedrockChat",
"ChatBedrock",
"ChatBedrockConverse",
"SagemakerEndpoint",
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
Expand Down
3 changes: 2 additions & 1 deletion libs/aws/langchain_aws/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse

__all__ = ["BedrockChat", "ChatBedrock"]
__all__ = ["BedrockChat", "ChatBedrock", "ChatBedrockConverse"]
40 changes: 40 additions & 0 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool

from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
from langchain_aws.function_calling import (
ToolsOutputParser,
_lc_tool_calls_to_anthropic_tool_use_blocks,
Expand Down Expand Up @@ -387,6 +388,9 @@ class ChatBedrock(BaseChatModel, BedrockBase):
"""A chat model that uses the Bedrock API."""

system_prompt_with_tools: str = ""
beta_use_converse_api: bool = False
"""Use the new Bedrock ``converse`` API which provides a standardized interface to
all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more."""

@property
def _llm_type(self) -> str:
Expand Down Expand Up @@ -424,6 +428,11 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if self.beta_use_converse_api:
yield from self._as_converse._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None

Expand Down Expand Up @@ -490,6 +499,10 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.beta_use_converse_api:
return self._as_converse._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
completion = ""
llm_output: Dict[str, Any] = {}
tool_calls: List[Dict[str, Any]] = []
Expand Down Expand Up @@ -608,6 +621,12 @@ def bind_tools(
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
if self.beta_use_converse_api:
if isinstance(tool_choice, bool):
tool_choice = "any" if tool_choice else None
return self._as_converse.bind_tools(
tools, tool_choice=tool_choice, **kwargs
)
if self._get_provider() == "anthropic":
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]

Expand Down Expand Up @@ -745,6 +764,10 @@ class AnswerWithJustification(BaseModel):
# }

""" # noqa: E501
if self.beta_use_converse_api:
return self._as_converse.with_structured_output(
schema, include_raw=include_raw, **kwargs
)
if "claude-3" not in self._get_model():
ValueError(
f"Structured output is not supported for model {self._get_model()}"
Expand All @@ -769,6 +792,23 @@ class AnswerWithJustification(BaseModel):
else:
return llm | output_parser

@property
def _as_converse(self) -> ChatBedrockConverse:
kwargs = {
k: v
for k, v in (self.model_kwargs or {}).items()
if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
}
return ChatBedrockConverse(
model=self.model_id,
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
config=self.config,
provider=self.provider or "",
base_url=self.endpoint_url,
**kwargs,
)


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
class BedrockChat(ChatBedrock):
Expand Down
Loading