Skip to content

Commit 63a7338

Browse files
committed
feat: working and correspoding integration test
1 parent 5f1ea17 commit 63a7338

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

libs/aws/langchain_aws/function_calling.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from langchain_core.messages import ToolCall
1818
from langchain_core.output_parsers import BaseGenerationOutputParser
1919
from langchain_core.outputs import ChatGeneration, Generation
20+
from langchain_core.prompts.chat import AIMessage
2021
from langchain_core.pydantic_v1 import BaseModel
2122
from langchain_core.tools import BaseTool
2223
from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -177,17 +178,13 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
177178
if not result or not isinstance(result[0], ChatGeneration):
178179
return None if self.first_tool_only else []
179180
message = result[0].message
180-
if isinstance(message.content, str):
181+
if len(message.content) > 0:
181182
tool_calls: List = []
182183
else:
183-
content: List = message.content
184-
_tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
184+
content = cast(AIMessage, message)
185+
_tool_calls = [dict(tc) for tc in content.tool_calls]
185186
# Map tool call id to index
186-
id_to_index = {
187-
block["id"]: i
188-
for i, block in enumerate(content)
189-
if block["type"] == "tool_use"
190-
}
187+
id_to_index = {block["id"]: i for i, block in enumerate(_tool_calls)}
191188
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
192189
if self.pydantic_schemas:
193190
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
@@ -208,17 +205,6 @@ def _pydantic_parse(self, tool_call: dict) -> BaseModel:
208205
return cls_(**tool_call["args"])
209206

210207

211-
def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
212-
tool_calls = []
213-
for block in content:
214-
if block["type"] != "tool_use":
215-
continue
216-
tool_calls.append(
217-
ToolCall(name=block["name"], args=block["input"], id=block["id"])
218-
)
219-
return tool_calls
220-
221-
222208
def convert_to_anthropic_tool(
223209
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
224210
) -> AnthropicTool:

libs/aws/tests/integration_tests/chat_models/test_bedrock.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,28 @@ class GetWeather(BaseModel):
190190
location: str = Field(..., description="The city and state")
191191

192192

193+
class AnswerWithJustification(BaseModel):
194+
"""An answer to the user question along with justification for the answer."""
195+
196+
answer: str
197+
justification: str
198+
199+
200+
@pytest.mark.scheduled
201+
def test_structured_output() -> None:
202+
chat = ChatBedrock(
203+
model_id="anthropic.claude-3-sonnet-20240229-v1:0",
204+
model_kwargs={"temperature": 0.001},
205+
) # type: ignore[call-arg]
206+
structured_llm = chat.with_structured_output(AnswerWithJustification)
207+
208+
response = structured_llm.invoke(
209+
"What weighs more a pound of bricks or a pound of feathers"
210+
)
211+
212+
assert isinstance(response, AnswerWithJustification)
213+
214+
193215
@pytest.mark.scheduled
194216
def test_tool_use_call_invoke() -> None:
195217
chat = ChatBedrock(

0 commit comments

Comments
 (0)