@@ -12,6 +12,7 @@
     Mapping,
     Optional,
     Tuple,
+    TypedDict,
     Union,
 )

@@ -21,10 +22,12 @@
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LLM, BaseLanguageModel
+from langchain_core.messages import ToolCall
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env

+from langchain_aws.function_calling import _tools_in_params
 from langchain_aws.utils import (
     enforce_stop_tokens,
     get_num_tokens_anthropic,
@@ -81,7 +84,10 @@ def _human_assistant_format(input_text: str) -> str:


 def _stream_response_to_generation_chunk(
-    stream_response: Dict[str, Any], provider: str, output_key: str, messages_api: bool
+    stream_response: Dict[str, Any],
+    provider: str,
+    output_key: str,
+    messages_api: bool,
 ) -> Union[GenerationChunk, None]:
     """Convert a stream response to a generation chunk."""
     if messages_api:
@@ -174,6 +180,23 @@ def _combine_generation_info_for_llm_result(
     return {"usage": total_usage_info, "stop_reason": stop_reason}


+def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
+    tool_calls = []
+    for block in content:
+        if block["type"] != "tool_use":
+            continue
+        tool_calls.append(
+            ToolCall(name=block["name"], args=block["input"], id=block["id"])
+        )
+    return tool_calls
+
+
+class AnthropicTool(TypedDict):
+    name: str
+    description: str
+    input_schema: Dict[str, Any]
+
+
 class LLMInputOutputAdapter:
     """Adapter class to prepare the inputs from Langchain to a format
     that LLM model expects.
@@ -197,10 +220,13 @@ def prepare_input(
         prompt: Optional[str] = None,
         system: Optional[str] = None,
         messages: Optional[List[Dict]] = None,
+        tools: Optional[List[AnthropicTool]] = None,
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
         if provider == "anthropic":
             if messages:
+                if tools:
+                    input_body["tools"] = tools
                 input_body["anthropic_version"] = "bedrock-2023-05-31"
                 input_body["messages"] = messages
                 if system:
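
With the new `tools` parameter, the request body built for an Anthropic messages call comes out roughly as follows. A hedged sketch: the `max_tokens` value and the tool definition are illustrative, not taken from this diff; only the key names and ordering follow `prepare_input`:

    input_body = {
        "max_tokens": 1024,  # carried over from model_kwargs (illustrative)
        "tools": [
            {  # an AnthropicTool
                "name": "get_weather",
                "description": "Look up current weather for a city.",
                "input_schema": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            }
        ],
        "anthropic_version": "bedrock-2023-05-31",
        "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    }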
@@ -225,16 +251,20 @@ def prepare_input(
     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
         text = ""
+        tool_calls = []
+        response_body = json.loads(response.get("body").read().decode())
+
         if provider == "anthropic":
-            response_body = json.loads(response.get("body").read().decode())
             if "completion" in response_body:
                 text = response_body.get("completion")
             elif "content" in response_body:
                 content = response_body.get("content")
-                text = content[0].get("text")
-        else:
-            response_body = json.loads(response.get("body").read())
+                if len(content) == 1 and content[0]["type"] == "text":
+                    text = content[0]["text"]
+                elif any(block["type"] == "tool_use" for block in content):
+                    tool_calls = extract_tool_calls(content)

+        else:
             if provider == "ai21":
                 text = response_body.get("completions")[0].get("data").get("text")
             elif provider == "cohere":
@@ -251,6 +281,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
         completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
         return {
             "text": text,
+            "tool_calls": tool_calls,
             "body": response_body,
             "usage": {
                 "prompt_tokens": prompt_tokens,
@@ -584,19 +615,32 @@ def _prepare_input_and_invoke(
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Tuple[str, Dict[str, Any]]:
+    ) -> Tuple[
+        str,
+        List[dict],
+        Dict[str, Any],
+    ]:
         _model_kwargs = self.model_kwargs or {}

         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
-
         input_body = LLMInputOutputAdapter.prepare_input(
             provider=provider,
             model_kwargs=params,
             prompt=prompt,
             system=system,
             messages=messages,
         )
+        if "claude-3" in self._get_model():
+            if _tools_in_params(params):
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    model_kwargs=params,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    tools=params["tools"],
+                )
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
@@ -621,9 +665,13 @@ def _prepare_input_and_invoke(
         try:
             response = self.client.invoke_model(**request_options)

-            text, body, usage_info, stop_reason = LLMInputOutputAdapter.prepare_output(
-                provider, response
-            ).values()
+            (
+                text,
+                tool_calls,
+                body,
+                usage_info,
+                stop_reason,
+            ) = LLMInputOutputAdapter.prepare_output(provider, response).values()

         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
@@ -646,7 +694,7 @@ def _prepare_input_and_invoke(
             **services_trace,
         )

-        return text, llm_output
+        return text, tool_calls, llm_output

     def _get_bedrock_services_signal(self, body: dict) -> dict:
         """
@@ -711,6 +759,16 @@ def _prepare_input_and_invoke_stream(
             messages=messages,
             model_kwargs=params,
         )
+        if "claude-3" in self._get_model():
+            if _tools_in_params(params):
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    model_kwargs=params,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    tools=params["tools"],
+                )
         body = json.dumps(input_body)

         request_options = {
@@ -737,7 +795,10 @@ def _prepare_input_and_invoke_stream(
             raise ValueError(f"Error raised by bedrock service: {e}")

         for chunk in LLMInputOutputAdapter.prepare_output_stream(
-            provider, response, stop, True if messages else False
+            provider,
+            response,
+            stop,
+            True if messages else False,
         ):
             yield chunk
             # verify and raise callback error if any middleware intervened
@@ -770,13 +831,24 @@ async def _aprepare_input_and_invoke_stream(
         _model_kwargs["stream"] = True

         params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(
-            provider=provider,
-            prompt=prompt,
-            system=system,
-            messages=messages,
-            model_kwargs=params,
-        )
+        if "claude-3" in self._get_model():
+            if _tools_in_params(params):
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    model_kwargs=params,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    tools=params["tools"],
+                )
+            else:
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    model_kwargs=params,
+                )
         body = json.dumps(input_body)

         response = await asyncio.get_running_loop().run_in_executor(
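
One caution on this async branch: unlike the sync path, which always builds a default `input_body` before the claude-3 check, here the assignment happens only inside `if "claude-3" in self._get_model():`, so a non-claude-3 model would reach `json.dumps(input_body)` with the name unbound. A hedged sketch (not part of this diff) that mirrors the sync path instead:

    input_body = LLMInputOutputAdapter.prepare_input(
        provider=provider,
        prompt=prompt,
        system=system,
        messages=messages,
        model_kwargs=params,
    )
    if "claude-3" in self._get_model() and _tools_in_params(params):
        input_body = LLMInputOutputAdapter.prepare_input(
            provider=provider,
            model_kwargs=params,
            prompt=prompt,
            system=system,
            messages=messages,
            tools=params["tools"],
        )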
@@ -790,7 +862,10 @@ async def _aprepare_input_and_invoke_stream(
         )

         async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
-            provider, response, stop, True if messages else False
+            provider,
+            response,
+            stop,
+            True if messages else False,
         ):
             yield chunk
         if run_manager is not None and asyncio.iscoroutinefunction(
@@ -951,7 +1026,7 @@ def _call(

             return completion

-        text, llm_output = self._prepare_input_and_invoke(
+        text, tool_calls, llm_output = self._prepare_input_and_invoke(
             prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
         )
         if run_manager is not None:
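
Taken together, callers of the invoke path now get tool calls back alongside the text. A hypothetical end-to-end sketch, calling the private helper directly only to show the new return shape; the `BedrockLLM` import, the model id, and passing `tools` through `**kwargs` are assumptions, not verified against a released langchain-aws API:

    from langchain_aws import BedrockLLM  # assumed export

    llm = BedrockLLM(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
    text, tool_calls, llm_output = llm._prepare_input_and_invoke(
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=[
            {
                "name": "get_weather",
                "description": "Look up current weather for a city.",
                "input_schema": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            }
        ],
    )
    # tool_calls is non-empty when the model elects to call a tool;
    # otherwise text carries an ordinary completion.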