1212 Mapping ,
1313 Optional ,
1414 Tuple ,
15+ TypedDict ,
1516 Union ,
1617)
1718
2122 CallbackManagerForLLMRun ,
2223)
2324from langchain_core .language_models import LLM , BaseLanguageModel
25+ from langchain_core .messages import ToolCall
2426from langchain_core .outputs import Generation , GenerationChunk , LLMResult
2527from langchain_core .pydantic_v1 import Extra , Field , root_validator
2628from langchain_core .utils import get_from_dict_or_env
2729
30+ from langchain_aws .function_calling import _tools_in_params
2831from langchain_aws .utils import (
2932 enforce_stop_tokens ,
3033 get_num_tokens_anthropic ,
@@ -81,7 +84,10 @@ def _human_assistant_format(input_text: str) -> str:
8184
8285
8386def _stream_response_to_generation_chunk (
84- stream_response : Dict [str , Any ], provider : str , output_key : str , messages_api : bool
87+ stream_response : Dict [str , Any ],
88+ provider : str ,
89+ output_key : str ,
90+ messages_api : bool ,
8591) -> Union [GenerationChunk , None ]:
8692 """Convert a stream response to a generation chunk."""
8793 if messages_api :
@@ -174,6 +180,23 @@ def _combine_generation_info_for_llm_result(
174180 return {"usage" : total_usage_info , "stop_reason" : stop_reason }
175181
176182
183+ def extract_tool_calls (content : List [dict ]) -> List [ToolCall ]:
184+ tool_calls = []
185+ for block in content :
186+ if block ["type" ] != "tool_use" :
187+ continue
188+ tool_calls .append (
189+ ToolCall (name = block ["name" ], args = block ["input" ], id = block ["id" ])
190+ )
191+ return tool_calls
192+
193+
# Shape of a single tool definition as expected by Anthropic's Bedrock
# messages API: a tool name, a human-readable description, and a JSON-schema
# dict describing the tool's input parameters.
AnthropicTool = TypedDict(
    "AnthropicTool",
    {
        "name": str,
        "description": str,
        "input_schema": Dict[str, Any],
    },
)
199+
177200class LLMInputOutputAdapter :
178201 """Adapter class to prepare the inputs from Langchain to a format
179202 that LLM model expects.
@@ -197,10 +220,13 @@ def prepare_input(
197220 prompt : Optional [str ] = None ,
198221 system : Optional [str ] = None ,
199222 messages : Optional [List [Dict ]] = None ,
223+ tools : Optional [List [AnthropicTool ]] = None ,
200224 ) -> Dict [str , Any ]:
201225 input_body = {** model_kwargs }
202226 if provider == "anthropic" :
203227 if messages :
228+ if tools :
229+ input_body ["tools" ] = tools
204230 input_body ["anthropic_version" ] = "bedrock-2023-05-31"
205231 input_body ["messages" ] = messages
206232 if system :
@@ -225,16 +251,20 @@ def prepare_input(
225251 @classmethod
226252 def prepare_output (cls , provider : str , response : Any ) -> dict :
227253 text = ""
254+ tool_calls = []
255+ response_body = json .loads (response .get ("body" ).read ().decode ())
256+
228257 if provider == "anthropic" :
229- response_body = json .loads (response .get ("body" ).read ().decode ())
230258 if "completion" in response_body :
231259 text = response_body .get ("completion" )
232260 elif "content" in response_body :
233261 content = response_body .get ("content" )
234- text = content [0 ].get ("text" )
235- else :
236- response_body = json .loads (response .get ("body" ).read ())
262+ if len (content ) == 1 and content [0 ]["type" ] == "text" :
263+ text = content [0 ]["text" ]
264+ elif any (block ["type" ] == "tool_use" for block in content ):
265+ tool_calls = extract_tool_calls (content )
237266
267+ else :
238268 if provider == "ai21" :
239269 text = response_body .get ("completions" )[0 ].get ("data" ).get ("text" )
240270 elif provider == "cohere" :
@@ -251,6 +281,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
251281 completion_tokens = int (headers .get ("x-amzn-bedrock-output-token-count" , 0 ))
252282 return {
253283 "text" : text ,
284+ "tool_calls" : tool_calls ,
254285 "body" : response_body ,
255286 "usage" : {
256287 "prompt_tokens" : prompt_tokens ,
@@ -584,19 +615,32 @@ def _prepare_input_and_invoke(
584615 stop : Optional [List [str ]] = None ,
585616 run_manager : Optional [CallbackManagerForLLMRun ] = None ,
586617 ** kwargs : Any ,
587- ) -> Tuple [str , Dict [str , Any ]]:
618+ ) -> Tuple [
619+ str ,
620+ List [dict ],
621+ Dict [str , Any ],
622+ ]:
588623 _model_kwargs = self .model_kwargs or {}
589624
590625 provider = self ._get_provider ()
591626 params = {** _model_kwargs , ** kwargs }
592-
593627 input_body = LLMInputOutputAdapter .prepare_input (
594628 provider = provider ,
595629 model_kwargs = params ,
596630 prompt = prompt ,
597631 system = system ,
598632 messages = messages ,
599633 )
634+ if "claude-3" in self ._get_model ():
635+ if _tools_in_params (params ):
636+ input_body = LLMInputOutputAdapter .prepare_input (
637+ provider = provider ,
638+ model_kwargs = params ,
639+ prompt = prompt ,
640+ system = system ,
641+ messages = messages ,
642+ tools = params ["tools" ],
643+ )
600644 body = json .dumps (input_body )
601645 accept = "application/json"
602646 contentType = "application/json"
@@ -621,9 +665,13 @@ def _prepare_input_and_invoke(
621665 try :
622666 response = self .client .invoke_model (** request_options )
623667
624- text , body , usage_info , stop_reason = LLMInputOutputAdapter .prepare_output (
625- provider , response
626- ).values ()
668+ (
669+ text ,
670+ tool_calls ,
671+ body ,
672+ usage_info ,
673+ stop_reason ,
674+ ) = LLMInputOutputAdapter .prepare_output (provider , response ).values ()
627675
628676 except Exception as e :
629677 raise ValueError (f"Error raised by bedrock service: { e } " )
@@ -646,7 +694,7 @@ def _prepare_input_and_invoke(
646694 ** services_trace ,
647695 )
648696
649- return text , llm_output
697+ return text , tool_calls , llm_output
650698
651699 def _get_bedrock_services_signal (self , body : dict ) -> dict :
652700 """
@@ -711,6 +759,16 @@ def _prepare_input_and_invoke_stream(
711759 messages = messages ,
712760 model_kwargs = params ,
713761 )
762+ if "claude-3" in self ._get_model ():
763+ if _tools_in_params (params ):
764+ input_body = LLMInputOutputAdapter .prepare_input (
765+ provider = provider ,
766+ model_kwargs = params ,
767+ prompt = prompt ,
768+ system = system ,
769+ messages = messages ,
770+ tools = params ["tools" ],
771+ )
714772 body = json .dumps (input_body )
715773
716774 request_options = {
@@ -737,7 +795,10 @@ def _prepare_input_and_invoke_stream(
737795 raise ValueError (f"Error raised by bedrock service: { e } " )
738796
739797 for chunk in LLMInputOutputAdapter .prepare_output_stream (
740- provider , response , stop , True if messages else False
798+ provider ,
799+ response ,
800+ stop ,
801+ True if messages else False ,
741802 ):
742803 yield chunk
743804 # verify and raise callback error if any middleware intervened
@@ -770,13 +831,24 @@ async def _aprepare_input_and_invoke_stream(
770831 _model_kwargs ["stream" ] = True
771832
772833 params = {** _model_kwargs , ** kwargs }
773- input_body = LLMInputOutputAdapter .prepare_input (
774- provider = provider ,
775- prompt = prompt ,
776- system = system ,
777- messages = messages ,
778- model_kwargs = params ,
779- )
834+ if "claude-3" in self ._get_model ():
835+ if _tools_in_params (params ):
836+ input_body = LLMInputOutputAdapter .prepare_input (
837+ provider = provider ,
838+ model_kwargs = params ,
839+ prompt = prompt ,
840+ system = system ,
841+ messages = messages ,
842+ tools = params ["tools" ],
843+ )
844+ else :
845+ input_body = LLMInputOutputAdapter .prepare_input (
846+ provider = provider ,
847+ prompt = prompt ,
848+ system = system ,
849+ messages = messages ,
850+ model_kwargs = params ,
851+ )
780852 body = json .dumps (input_body )
781853
782854 response = await asyncio .get_running_loop ().run_in_executor (
@@ -790,7 +862,10 @@ async def _aprepare_input_and_invoke_stream(
790862 )
791863
792864 async for chunk in LLMInputOutputAdapter .aprepare_output_stream (
793- provider , response , stop , True if messages else False
865+ provider ,
866+ response ,
867+ stop ,
868+ True if messages else False ,
794869 ):
795870 yield chunk
796871 if run_manager is not None and asyncio .iscoroutinefunction (
@@ -951,7 +1026,7 @@ def _call(
9511026
9521027 return completion
9531028
954- text , llm_output = self ._prepare_input_and_invoke (
1029+ text , tool_calls , llm_output = self ._prepare_input_and_invoke (
9551030 prompt = prompt , stop = stop , run_manager = run_manager , ** kwargs
9561031 )
9571032 if run_manager is not None :
0 commit comments