
Commit 99409d8

Merge pull request #70 from laithalsaadoon/feat/claude-3-tool-calling
Feat: claude 3 tool calling
2 parents 1bdbe20 + 15c93ed commit 99409d8

File tree

8 files changed: +972 additions, −78 deletions

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 295 additions & 41 deletions
Large diffs are not rendered by default.
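Since the chat-model changes are collapsed above, here is a minimal usage sketch of what this PR enables end to end: tool calling with Claude 3 through ChatBedrock. It assumes a bind_tools method on ChatBedrock (implied by the PR title and the supporting code below; the authoritative definition lives in the unrendered diff), and the model id and schema are illustrative:

    from langchain_core.pydantic_v1 import BaseModel, Field

    from langchain_aws.chat_models.bedrock import ChatBedrock


    class GetWeather(BaseModel):
        """Get the current weather for a location."""

        location: str = Field(..., description="City and state, e.g. 'Seattle, WA'")


    # Assumed interface: bind_tools() attaches the tool schema to every call.
    llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
    llm_with_tools = llm.bind_tools([GetWeather])

    # The model may reply with tool calls instead of plain text.
    ai_msg = llm_with_tools.invoke("What's the weather in Seattle?")
    print(ai_msg.tool_calls)
    # e.g. [{'name': 'GetWeather', 'args': {'location': 'Seattle, WA'}, 'id': 'toolu_01...'}]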

libs/aws/langchain_aws/function_calling.py

Lines changed: 83 additions & 0 deletions
@@ -8,10 +8,16 @@
     Dict,
     List,
     Literal,
+    Optional,
     Type,
     Union,
+    cast,
 )
 
+from langchain_core.messages import ToolCall
+from langchain_core.output_parsers import BaseGenerationOutputParser
+from langchain_core.outputs import ChatGeneration, Generation
+from langchain_core.prompts.chat import AIMessage
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -63,6 +69,35 @@ class AnthropicTool(TypedDict):
     input_schema: Dict[str, Any]
 
 
+def _tools_in_params(params: dict) -> bool:
+    return "tools" in params or (
+        "extra_body" in params and params["extra_body"].get("tools")
+    )
+
+
+class _AnthropicToolUse(TypedDict):
+    type: Literal["tool_use"]
+    name: str
+    input: dict
+    id: str
+
+
+def _lc_tool_calls_to_anthropic_tool_use_blocks(
+    tool_calls: List[ToolCall],
+) -> List[_AnthropicToolUse]:
+    blocks = []
+    for tool_call in tool_calls:
+        blocks.append(
+            _AnthropicToolUse(
+                type="tool_use",
+                name=tool_call["name"],
+                input=tool_call["args"],
+                id=cast(str, tool_call["id"]),
+            )
+        )
+    return blocks
+
+
 def _get_type(parameter: Dict[str, Any]) -> str:
     if "type" in parameter:
         return parameter["type"]
@@ -122,6 +157,54 @@ class ToolDescription(TypedDict):
     function: FunctionDescription
 
 
+class ToolsOutputParser(BaseGenerationOutputParser):
+    first_tool_only: bool = False
+    args_only: bool = False
+    pydantic_schemas: Optional[List[Type[BaseModel]]] = None
+
+    class Config:
+        extra = "forbid"
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        """Parse a list of candidate model Generations into a specific format.
+
+        Args:
+            result: A list of Generations to be parsed. The Generations are assumed
+                to be different candidate outputs for a single model input.
+
+        Returns:
+            Structured output.
+        """
+        if not result or not isinstance(result[0], ChatGeneration):
+            return None if self.first_tool_only else []
+        message = result[0].message
+        if len(message.content) > 0:
+            tool_calls: List = []
+        else:
+            content = cast(AIMessage, message)
+            _tool_calls = [dict(tc) for tc in content.tool_calls]
+            # Map tool call id to index
+            id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)}
+            tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
+        if self.pydantic_schemas:
+            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
+        elif self.args_only:
+            tool_calls = [tc["args"] for tc in tool_calls]
+        else:
+            pass
+
+        if self.first_tool_only:
+            return tool_calls[0] if tool_calls else None
+        else:
+            return [tool_call for tool_call in tool_calls]
+
+    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
+        cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
+            tool_call["name"]
+        ]
+        return cls_(**tool_call["args"])
+
+
 def convert_to_anthropic_tool(
     tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
 ) -> AnthropicTool:
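A usage sketch for the new ToolsOutputParser: given a ChatGeneration whose message carries tool calls and no text content, it can return just the parsed arguments of the first call (values are illustrative):

    from langchain_core.messages import AIMessage
    from langchain_core.outputs import ChatGeneration

    from langchain_aws.function_calling import ToolsOutputParser

    # An AIMessage the way a Claude 3 tool-use response surfaces it:
    # empty text content, with the structured calls on tool_calls.
    message = AIMessage(
        content="",
        tool_calls=[
            {"name": "GetWeather", "args": {"location": "Seattle, WA"}, "id": "toolu_01"}
        ],
    )

    parser = ToolsOutputParser(first_tool_only=True, args_only=True)
    print(parser.parse_result([ChatGeneration(message=message)]))
    # {'location': 'Seattle, WA'}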

libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 96 additions & 21 deletions
@@ -12,6 +12,7 @@
     Mapping,
     Optional,
     Tuple,
+    TypedDict,
     Union,
 )
 
@@ -21,10 +22,12 @@
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LLM, BaseLanguageModel
+from langchain_core.messages import ToolCall
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
+from langchain_aws.function_calling import _tools_in_params
 from langchain_aws.utils import (
     enforce_stop_tokens,
     get_num_tokens_anthropic,
@@ -81,7 +84,10 @@ def _human_assistant_format(input_text: str) -> str:
 
 
 def _stream_response_to_generation_chunk(
-    stream_response: Dict[str, Any], provider: str, output_key: str, messages_api: bool
+    stream_response: Dict[str, Any],
+    provider: str,
+    output_key: str,
+    messages_api: bool,
 ) -> Union[GenerationChunk, None]:
     """Convert a stream response to a generation chunk."""
     if messages_api:
@@ -174,6 +180,23 @@ def _combine_generation_info_for_llm_result(
     return {"usage": total_usage_info, "stop_reason": stop_reason}
 
 
+def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
+    tool_calls = []
+    for block in content:
+        if block["type"] != "tool_use":
+            continue
+        tool_calls.append(
+            ToolCall(name=block["name"], args=block["input"], id=block["id"])
+        )
+    return tool_calls
+
+
+class AnthropicTool(TypedDict):
+    name: str
+    description: str
+    input_schema: Dict[str, Any]
+
+
 class LLMInputOutputAdapter:
     """Adapter class to prepare the inputs from Langchain to a format
     that LLM model expects.
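To illustrate extract_tool_calls: it walks an Anthropic messages-API content list and keeps only the tool_use blocks, converting each to a LangChain ToolCall (block values are illustrative):

    from langchain_aws.llms.bedrock import extract_tool_calls

    # Anthropic-style content blocks: one text block, one tool_use block.
    content = [
        {"type": "text", "text": "Let me check the weather."},
        {
            "type": "tool_use",
            "id": "toolu_01",
            "name": "GetWeather",
            "input": {"location": "Seattle, WA"},
        },
    ]

    # The text block is skipped; only the tool_use block is converted.
    print(extract_tool_calls(content))
    # [{'name': 'GetWeather', 'args': {'location': 'Seattle, WA'}, 'id': 'toolu_01'}]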
@@ -197,10 +220,13 @@ def prepare_input(
         prompt: Optional[str] = None,
         system: Optional[str] = None,
         messages: Optional[List[Dict]] = None,
+        tools: Optional[List[AnthropicTool]] = None,
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
         if provider == "anthropic":
             if messages:
+                if tools:
+                    input_body["tools"] = tools
                 input_body["anthropic_version"] = "bedrock-2023-05-31"
                 input_body["messages"] = messages
                 if system:
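For reference, a sketch of the request body prepare_input assembles for an Anthropic messages call once tools are passed (message and schema contents are illustrative):

    from langchain_aws.llms.bedrock import LLMInputOutputAdapter

    input_body = LLMInputOutputAdapter.prepare_input(
        provider="anthropic",
        model_kwargs={"max_tokens": 1024},
        messages=[{"role": "user", "content": "What's the weather in Seattle?"}],
        tools=[
            {
                "name": "GetWeather",
                "description": "Get the current weather for a location.",
                "input_schema": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            }
        ],
    )
    # -> {'max_tokens': 1024, 'tools': [...],
    #     'anthropic_version': 'bedrock-2023-05-31', 'messages': [...]}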
@@ -225,16 +251,20 @@ def prepare_input(
     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
         text = ""
+        tool_calls = []
+        response_body = json.loads(response.get("body").read().decode())
+
         if provider == "anthropic":
-            response_body = json.loads(response.get("body").read().decode())
             if "completion" in response_body:
                 text = response_body.get("completion")
             elif "content" in response_body:
                 content = response_body.get("content")
-                text = content[0].get("text")
-        else:
-            response_body = json.loads(response.get("body").read())
+                if len(content) == 1 and content[0]["type"] == "text":
+                    text = content[0]["text"]
+                elif any(block["type"] == "tool_use" for block in content):
+                    tool_calls = extract_tool_calls(content)
 
+        else:
             if provider == "ai21":
                 text = response_body.get("completions")[0].get("data").get("text")
             elif provider == "cohere":
@@ -251,6 +281,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
         completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
         return {
             "text": text,
+            "tool_calls": tool_calls,
             "body": response_body,
             "usage": {
                 "prompt_tokens": prompt_tokens,
@@ -584,19 +615,32 @@ def _prepare_input_and_invoke(
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Tuple[str, Dict[str, Any]]:
+    ) -> Tuple[
+        str,
+        List[dict],
+        Dict[str, Any],
+    ]:
         _model_kwargs = self.model_kwargs or {}
 
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
-
         input_body = LLMInputOutputAdapter.prepare_input(
             provider=provider,
             model_kwargs=params,
             prompt=prompt,
             system=system,
             messages=messages,
         )
+        if "claude-3" in self._get_model():
+            if _tools_in_params(params):
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    model_kwargs=params,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    tools=params["tools"],
+                )
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
@@ -621,9 +665,13 @@ def _prepare_input_and_invoke(
         try:
             response = self.client.invoke_model(**request_options)
 
-            text, body, usage_info, stop_reason = LLMInputOutputAdapter.prepare_output(
-                provider, response
-            ).values()
+            (
+                text,
+                tool_calls,
+                body,
+                usage_info,
+                stop_reason,
+            ) = LLMInputOutputAdapter.prepare_output(provider, response).values()
 
         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
@@ -646,7 +694,7 @@
             **services_trace,
         )
 
-        return text, llm_output
+        return text, tool_calls, llm_output
 
     def _get_bedrock_services_signal(self, body: dict) -> dict:
         """
@@ -711,6 +759,16 @@ def _prepare_input_and_invoke_stream(
             messages=messages,
             model_kwargs=params,
         )
+        if "claude-3" in self._get_model():
+            if _tools_in_params(params):
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    model_kwargs=params,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    tools=params["tools"],
+                )
         body = json.dumps(input_body)
 
         request_options = {
@@ -737,7 +795,10 @@
             raise ValueError(f"Error raised by bedrock service: {e}")
 
         for chunk in LLMInputOutputAdapter.prepare_output_stream(
-            provider, response, stop, True if messages else False
+            provider,
+            response,
+            stop,
+            True if messages else False,
         ):
             yield chunk
         # verify and raise callback error if any middleware intervened
@@ -770,13 +831,24 @@ async def _aprepare_input_and_invoke_stream(
         _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(
-            provider=provider,
-            prompt=prompt,
-            system=system,
-            messages=messages,
-            model_kwargs=params,
-        )
+        if "claude-3" in self._get_model():
+            if _tools_in_params(params):
+                input_body = LLMInputOutputAdapter.prepare_input(
+                    provider=provider,
+                    model_kwargs=params,
+                    prompt=prompt,
+                    system=system,
+                    messages=messages,
+                    tools=params["tools"],
+                )
+        else:
+            input_body = LLMInputOutputAdapter.prepare_input(
+                provider=provider,
+                prompt=prompt,
+                system=system,
+                messages=messages,
+                model_kwargs=params,
+            )
         body = json.dumps(input_body)
 
         response = await asyncio.get_running_loop().run_in_executor(
@@ -790,7 +862,10 @@
         )
 
         async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
-            provider, response, stop, True if messages else False
+            provider,
+            response,
+            stop,
+            True if messages else False,
         ):
             yield chunk
         if run_manager is not None and asyncio.iscoroutinefunction(
@@ -951,7 +1026,7 @@ def _call(
 
             return completion
 
-        text, llm_output = self._prepare_input_and_invoke(
+        text, tool_calls, llm_output = self._prepare_input_and_invoke(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        )
        if run_manager is not None:
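Taken together, these pieces support a chain like the following sketch. The bind_tools method is assumed from the collapsed chat-model diff, and the model id and schema are illustrative:

    from langchain_core.pydantic_v1 import BaseModel, Field

    from langchain_aws.chat_models.bedrock import ChatBedrock
    from langchain_aws.function_calling import ToolsOutputParser


    class GetWeather(BaseModel):
        """Get the current weather for a location."""

        location: str = Field(..., description="City and state")


    llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")

    # Parse the model's tool call straight into the pydantic schema.
    chain = llm.bind_tools([GetWeather]) | ToolsOutputParser(
        first_tool_only=True, pydantic_schemas=[GetWeather]
    )
    print(chain.invoke("What's the weather in Seattle?"))
    # GetWeather(location='Seattle, WA')  (illustrative output)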
