diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 2fcf346e..92e1803a 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -193,15 +193,19 @@ def _get_invocation_metrics_chunk(chunk: Dict[str, Any]) -> GenerationChunk: return GenerationChunk(text="", generation_info=generation_info) -def extract_tool_calls(content: List[dict]) -> List[ToolCall]: +def extract_tool_calls_and_text(content: List[dict]) -> Tuple[List[ToolCall], str]: + text = "" tool_calls = [] for block in content: + if block["type"] == "text": + text = block["text"] + continue if block["type"] != "tool_use": continue tool_calls.append( ToolCall(name=block["name"], args=block["input"], id=block["id"]) ) - return tool_calls + return tool_calls, text class AnthropicTool(TypedDict): @@ -266,7 +270,6 @@ def prepare_output(cls, provider: str, response: Any) -> dict: text = "" tool_calls = [] response_body = json.loads(response.get("body").read().decode()) - if provider == "anthropic": if "completion" in response_body: text = response_body.get("completion") @@ -275,7 +278,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict: if len(content) == 1 and content[0]["type"] == "text": text = content[0]["text"] elif any(block["type"] == "tool_use" for block in content): - tool_calls = extract_tool_calls(content) + tool_calls, text = extract_tool_calls_and_text(content) else: if provider == "ai21":