Skip to content

Commit 625075a

Browse files
authored
fix(aws): add tool_choice support for claude 4 models (#8601)
2 parents 0d865c7 + 8a9e83b commit 625075a

File tree

4 files changed

+76
-50
lines changed

4 files changed

+76
-50
lines changed

libs/langchain-aws/src/chat_models.ts

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ import {
5050
handleConverseStreamMetadata,
5151
handleConverseStreamContentBlockStart,
5252
BedrockConverseToolChoice,
53+
supportedToolChoiceValuesForModel,
5354
} from "./common.js";
5455
import {
5556
ChatBedrockConverseToolType,
@@ -732,13 +733,9 @@ export class ChatBedrockConverse
732733
this.performanceConfig = rest?.performanceConfig;
733734

734735
if (rest?.supportsToolChoiceValues === undefined) {
735-
if (this.model.includes("claude-3")) {
736-
this.supportsToolChoiceValues = ["auto", "any", "tool"];
737-
} else if (this.model.includes("mistral-large")) {
738-
this.supportsToolChoiceValues = ["auto", "any"];
739-
} else {
740-
this.supportsToolChoiceValues = undefined;
741-
}
736+
this.supportsToolChoiceValues = supportedToolChoiceValuesForModel(
737+
this.model
738+
);
742739
} else {
743740
this.supportsToolChoiceValues = rest.supportsToolChoiceValues;
744741
}

libs/langchain-aws/src/common.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,3 +961,20 @@ export function concatenateLangchainReasoningBlocks(
961961
}
962962
return concatenatedBlocks;
963963
}
964+
965+
export function supportedToolChoiceValuesForModel(
966+
model: string
967+
): Array<"auto" | "any" | "tool"> | undefined {
968+
if (
969+
model.includes("claude-3") ||
970+
model.includes("claude-4") ||
971+
model.includes("claude-opus-4") ||
972+
model.includes("claude-sonnet-4")
973+
) {
974+
return ["auto", "any", "tool"];
975+
}
976+
if (model.includes("mistral-large")) {
977+
return ["auto", "any"];
978+
}
979+
return undefined;
980+
}

libs/langchain-aws/src/tests/chat_models.int.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ test("Test ChatBedrockConverse can stream tools", async () => {
286286
test("Test ChatBedrockConverse tool_choice works", async () => {
287287
const model = new ChatBedrockConverse({
288288
...baseConstructorArgs,
289+
model: "us.anthropic.claude-sonnet-4-20250514-v1:0",
289290
});
290291
const tools = [
291292
tool(

libs/langchain-aws/src/tests/chat_models.test.ts

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -498,50 +498,61 @@ describe("tool_choice works for supported models", () => {
498498
);
499499
});
500500

501-
it("should bind tool_choice when using WSO with supported models", async () => {
502-
// Claude 3 should NOT throw is using WSO & it should have `tool_choice` bound.
503-
const claude3Model = new ChatBedrockConverse({
504-
...baseConstructorArgs,
505-
model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
506-
// We are not passing the `supportsToolChoiceValues` arg here as
507-
// it should be inferred from the model name.
508-
});
509-
const claude3ModelWSO = claude3Model.withStructuredOutput(tool.schema, {
510-
name: tool.name,
511-
});
512-
expect(claude3ModelWSO).toBeDefined();
513-
const claude3ModelWSOAsJSON = claude3ModelWSO.toJSON();
514-
if (!("kwargs" in claude3ModelWSOAsJSON)) {
515-
throw new Error("kwargs not found in claude3ModelWSOAsJSON");
501+
it.each([
502+
"anthropic.claude-3-5-sonnet-20240620-v1:0",
503+
"anthropic.claude-sonnet-4-20250514-v1:0",
504+
])(
505+
"should bind tool_choice when using WSO with model that supports tool choice: %s",
506+
(model) => {
507+
// Claude 3 should NOT throw is using WSO & it should have `tool_choice` bound.
508+
const claude3Model = new ChatBedrockConverse({
509+
...baseConstructorArgs,
510+
model,
511+
// We are not passing the `supportsToolChoiceValues` arg here as
512+
// it should be inferred from the model name.
513+
});
514+
const claude3ModelWSO = claude3Model.withStructuredOutput(tool.schema, {
515+
name: tool.name,
516+
});
517+
expect(claude3ModelWSO).toBeDefined();
518+
const claude3ModelWSOAsJSON = claude3ModelWSO.toJSON();
519+
if (!("kwargs" in claude3ModelWSOAsJSON)) {
520+
throw new Error("kwargs not found in claude3ModelWSOAsJSON");
521+
}
522+
expect(claude3ModelWSOAsJSON.kwargs.bound.first.config).toHaveProperty(
523+
"tool_choice"
524+
);
525+
expect(claude3ModelWSOAsJSON.kwargs.bound.first.config.tool_choice).toBe(
526+
tool.name
527+
);
516528
}
517-
expect(claude3ModelWSOAsJSON.kwargs.bound.first.config).toHaveProperty(
518-
"tool_choice"
519-
);
520-
expect(claude3ModelWSOAsJSON.kwargs.bound.first.config.tool_choice).toBe(
521-
tool.name
522-
);
529+
);
523530

524-
// Mistral (not mistral large) should NOT throw is using WSO
525-
const mistralModel = new ChatBedrockConverse({
526-
...baseConstructorArgs,
527-
model: "mistral.mistral-large-2407-v1:0",
528-
// We are not passing the `supportsToolChoiceValues` arg here as
529-
// it should be inferred from the model name.
530-
});
531-
const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, {
532-
name: tool.name,
533-
});
534-
expect(mistralModelWSO).toBeDefined();
535-
const mistralModelWSOAsJSON = mistralModelWSO.toJSON();
536-
if (!("kwargs" in mistralModelWSOAsJSON)) {
537-
throw new Error("kwargs not found in mistralModelWSOAsJSON");
531+
it.each(["mistral.mistral-large-2407-v1:0"])(
532+
"should bind tool_choice when using WSO with model that doesn't support tool choice: %s",
533+
(model) => {
534+
// Mistral (not mistral large) should NOT throw is using WSO
535+
const mistralModel = new ChatBedrockConverse({
536+
...baseConstructorArgs,
537+
model,
538+
// We are not passing the `supportsToolChoiceValues` arg here as
539+
// it should be inferred from the model name.
540+
});
541+
const mistralModelWSO = mistralModel.withStructuredOutput(tool.schema, {
542+
name: tool.name,
543+
});
544+
expect(mistralModelWSO).toBeDefined();
545+
const mistralModelWSOAsJSON = mistralModelWSO.toJSON();
546+
if (!("kwargs" in mistralModelWSOAsJSON)) {
547+
throw new Error("kwargs not found in mistralModelWSOAsJSON");
548+
}
549+
expect(mistralModelWSOAsJSON.kwargs.bound.first.config).toHaveProperty(
550+
"tool_choice"
551+
);
552+
// Mistral large only supports "auto" and "any" for tool_choice, not the actual tool name
553+
expect(mistralModelWSOAsJSON.kwargs.bound.first.config.tool_choice).toBe(
554+
"any"
555+
);
538556
}
539-
expect(mistralModelWSOAsJSON.kwargs.bound.first.config).toHaveProperty(
540-
"tool_choice"
541-
);
542-
// Mistral large only supports "auto" and "any" for tool_choice, not the actual tool name
543-
expect(mistralModelWSOAsJSON.kwargs.bound.first.config.tool_choice).toBe(
544-
"any"
545-
);
546-
});
557+
);
547558
});

0 commit comments

Comments
 (0)