37
37
from langchain_core .runnables import Runnable , RunnableMap , RunnablePassthrough
38
38
from langchain_core .tools import BaseTool
39
39
40
+ from langchain_aws .chat_models .bedrock_converse import ChatBedrockConverse
40
41
from langchain_aws .function_calling import (
41
42
ToolsOutputParser ,
42
43
_lc_tool_calls_to_anthropic_tool_use_blocks ,
@@ -387,6 +388,9 @@ class ChatBedrock(BaseChatModel, BedrockBase):
387
388
"""A chat model that uses the Bedrock API."""
388
389
389
390
system_prompt_with_tools : str = ""
391
+ beta_use_converse_api : bool = False
392
+ """Use the new Bedrock ``converse`` API which provides a standardized interface to
393
+ all Bedrock models. Support still in beta. See ChatBedrockConverse docs for more."""
390
394
391
395
@property
392
396
def _llm_type (self ) -> str :
@@ -424,6 +428,11 @@ def _stream(
424
428
run_manager : Optional [CallbackManagerForLLMRun ] = None ,
425
429
** kwargs : Any ,
426
430
) -> Iterator [ChatGenerationChunk ]:
431
+ if self .beta_use_converse_api :
432
+ yield from self ._as_converse ._stream (
433
+ messages , stop = stop , run_manager = run_manager , ** kwargs
434
+ )
435
+ return
427
436
provider = self ._get_provider ()
428
437
prompt , system , formatted_messages = None , None , None
429
438
@@ -490,6 +499,10 @@ def _generate(
490
499
run_manager : Optional [CallbackManagerForLLMRun ] = None ,
491
500
** kwargs : Any ,
492
501
) -> ChatResult :
502
+ if self .beta_use_converse_api :
503
+ return self ._as_converse ._generate (
504
+ messages , stop = stop , run_manager = run_manager , ** kwargs
505
+ )
493
506
completion = ""
494
507
llm_output : Dict [str , Any ] = {}
495
508
tool_calls : List [Dict [str , Any ]] = []
@@ -608,6 +621,12 @@ def bind_tools(
608
621
**kwargs: Any additional parameters to pass to the
609
622
:class:`~langchain.runnable.Runnable` constructor.
610
623
"""
624
+ if self .beta_use_converse_api :
625
+ if isinstance (tool_choice , bool ):
626
+ tool_choice = "any" if tool_choice else None
627
+ return self ._as_converse .bind_tools (
628
+ tools , tool_choice = tool_choice , ** kwargs
629
+ )
611
630
if self ._get_provider () == "anthropic" :
612
631
formatted_tools = [convert_to_anthropic_tool (tool ) for tool in tools ]
613
632
@@ -745,6 +764,10 @@ class AnswerWithJustification(BaseModel):
745
764
# }
746
765
747
766
""" # noqa: E501
767
+ if self .beta_use_converse_api :
768
+ return self ._as_converse .with_structured_output (
769
+ schema , include_raw = include_raw , ** kwargs
770
+ )
748
771
if "claude-3" not in self ._get_model ():
749
772
ValueError (
750
773
f"Structured output is not supported for model { self ._get_model ()} "
@@ -769,6 +792,23 @@ class AnswerWithJustification(BaseModel):
769
792
else :
770
793
return llm | output_parser
771
794
795
+ @property
796
+ def _as_converse (self ) -> ChatBedrockConverse :
797
+ kwargs = {
798
+ k : v
799
+ for k , v in (self .model_kwargs or {}).items ()
800
+ if k in ("stop" , "stop_sequences" , "max_tokens" , "temperature" , "top_p" )
801
+ }
802
+ return ChatBedrockConverse (
803
+ model = self .model_id ,
804
+ region_name = self .region_name ,
805
+ credentials_profile_name = self .credentials_profile_name ,
806
+ config = self .config ,
807
+ provider = self .provider or "" ,
808
+ base_url = self .endpoint_url ,
809
+ ** kwargs ,
810
+ )
811
+
772
812
773
813
@deprecated (since = "0.1.0" , removal = "0.2.0" , alternative = "ChatBedrock" )
774
814
class BedrockChat (ChatBedrock ):
0 commit comments