diff --git a/src/api/providers/__tests__/bedrock-invokedModelId.test.ts b/src/api/providers/__tests__/bedrock-invokedModelId.test.ts new file mode 100644 index 00000000000..eb95227507a --- /dev/null +++ b/src/api/providers/__tests__/bedrock-invokedModelId.test.ts @@ -0,0 +1,313 @@ +// Mock AWS SDK credential providers +jest.mock("@aws-sdk/credential-providers", () => ({ + fromIni: jest.fn().mockReturnValue({ + accessKeyId: "profile-access-key", + secretAccessKey: "profile-secret-key", + }), +})) + +import { AwsBedrockHandler, StreamEvent } from "../bedrock" +import { ApiHandlerOptions } from "../../../shared/api" +import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" + +describe("AwsBedrockHandler with invokedModelId", () => { + let mockSend: jest.SpyInstance + + beforeEach(() => { + // Mock the BedrockRuntimeClient.prototype.send method + mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => { + return { + stream: createMockStream([]), + } + }) + }) + + afterEach(() => { + mockSend.mockRestore() + }) + + // Helper function to create a mock async iterable stream + function createMockStream(events: StreamEvent[]) { + return { + [Symbol.asyncIterator]: async function* () { + for (const event of events) { + yield event + } + // Always yield a metadata event at the end + yield { + metadata: { + usage: { + inputTokens: 100, + outputTokens: 200, + }, + }, + } + }, + } + } + + it("should update costModelConfig when invokedModelId is present in the stream", async () => { + // Create a handler with a custom ARN + const mockOptions: ApiHandlerOptions = { + // apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-west-2:699475926481:default-prompt-router/anthropic.claude:1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Create a spy on the getModel method before mocking it + const getModelSpy = jest.spyOn(handler, "getModelByName") + + // Mock the stream to include an event with invokedModelId and usage metadata + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // First event with invokedModelId and usage metadata + { + trace: { + promptRouter: { + invokedModelId: + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0", + usage: { + inputTokens: 150, + outputTokens: 250, + }, + }, + }, + // Some content events + }, + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + { + contentBlockDelta: { + delta: { + text: ", world!", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Collect all yielded events to verify usage events + const events = [] + for await (const event of messageGenerator) { + events.push(event) + } + + // Verify that getModel was called with the correct model name + expect(getModelSpy).toHaveBeenCalledWith("anthropic.claude-3-5-sonnet-20240620-v1:0") + + // Verify that getModel returns the updated model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20240620-v1:0") + expect(costModel.info.inputPrice).toBe(3) + + // Verify that a usage event was emitted after updating the costModelConfig + const usageEvents = events.filter((event) => event.type === "usage") + expect(usageEvents.length).toBeGreaterThanOrEqual(1) + + // The last usage event should have the token counts from the metadata + const lastUsageEvent = usageEvents[usageEvents.length - 1] + expect(lastUsageEvent).toEqual({ + type: "usage", + inputTokens: 100, + outputTokens: 200, + }) + }) + + it("should not update costModelConfig when invokedModelId is not present", async () => { + // Create a handler with default settings + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream without an invokedModelId event + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Some content events but no invokedModelId + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + { + contentBlockDelta: { + delta: { + text: ", world!", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Mock getModel to return expected values + const getModelSpy = jest.spyOn(handler, "getModel").mockReturnValue({ + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Consume the generator + for await (const _ of messageGenerator) { + // Just consume the messages + } + + // Verify that getModel returns the original model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + + // Verify getModel was not called with a model name parameter + expect(getModelSpy).not.toHaveBeenCalledWith(expect.any(String)) + }) + + it("should handle invalid invokedModelId format gracefully", async () => { + // Create a handler with default settings + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream with an invalid invokedModelId + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Event with invalid invokedModelId format + { + trace: { + promptRouter: { + invokedModelId: "invalid-format-not-an-arn", + }, + }, + }, + // Some content events + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Mock getModel to return expected values + const getModelSpy = jest.spyOn(handler, "getModel").mockReturnValue({ + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Consume the generator + for await (const _ of messageGenerator) { + // Just consume the messages + } + + // Verify that getModel returns the original model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should handle errors during invokedModelId processing", async () => { + // Create a handler with default settings + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream with a valid invokedModelId + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Event with valid invokedModelId + { + trace: { + promptRouter: { + invokedModelId: + "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", + }, + }, + }, + ]), + } + }) + + // Mock getModel to throw an error when called with the model name + jest.spyOn(handler, "getModel").mockImplementation((modelName?: string) => { + if (modelName === "anthropic.claude-3-sonnet-20240229-v1:0") { + throw new Error("Test error during model lookup") + } + + // Default return value for initial call + return { + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + } + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Consume the generator + for await (const _ of messageGenerator) { + // Just consume the messages + } + + // Verify that getModel returns the original model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) +}) diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index 45d52702376..0094c3f12ba 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -327,10 +327,36 @@ describe("AwsBedrockHandler", () => { const modelInfo = customArnHandler.getModel() expect(modelInfo.id).toBe("arn:aws:bedrock:us-east-1::foundation-model/custom-model") expect(modelInfo.info.maxTokens).toBe(4096) - expect(modelInfo.info.contextWindow).toBe(128_000) + expect(modelInfo.info.contextWindow).toBe(200_000) expect(modelInfo.info.supportsPromptCache).toBe(false) }) + it("should correctly identify model info from inference profile ARN", () => { + //this test intentionally uses a model that has different maxTokens, contextWindow and other values than the fall back option in the code + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "meta.llama3-8b-instruct-v1:0", // This will be ignored when awsCustomArn is provided + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-west-2", + awsCustomArn: + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0", + }) + const modelInfo = customArnHandler.getModel() + + // Verify the ARN is used as the model ID + expect(modelInfo.id).toBe( + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0", + ) + + //these should not be the default fall back. they should be Llama's config + expect(modelInfo.info.maxTokens).toBe(2048) + expect(modelInfo.info.contextWindow).toBe(4_000) + expect(modelInfo.info.supportsImages).toBe(false) + expect(modelInfo.info.supportsPromptCache).toBe(false) + + // This test highlights that the regex in getModel needs to be updated to handle inference-profile ARNs + }) + it("should use default model when custom-arn is selected but no ARN is provided", () => { const customArnHandler = new AwsBedrockHandler({ apiModelId: "custom-arn", @@ -345,4 +371,163 @@ describe("AwsBedrockHandler", () => { expect(modelInfo.info).toBeDefined() }) }) + + describe("invokedModelId handling", () => { + it("should update costModelConfig when invokedModelId is present in custom ARN scenario", async () => { + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + invokedModelId: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model:0", + }, + }, + } + + jest.spyOn(customArnHandler, "getModel").mockReturnValue({ + id: "custom-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await customArnHandler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(customArnHandler.getModel()).toEqual({ + id: "custom-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + + it("should update costModelConfig when invokedModelId is present in default model scenario", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + invokedModelId: "arn:aws:bedrock:us-east-1:123456789:foundation-model/default-model:0", + }, + }, + } + + jest.spyOn(handler, "getModel").mockReturnValue({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(handler.getModel()).toEqual({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + + it("should not update costModelConfig when invokedModelId is not present", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + // No invokedModelId present + }, + }, + } + + jest.spyOn(handler, "getModel").mockReturnValue({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(handler.getModel()).toEqual({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + + it("should not update costModelConfig when invokedModelId cannot be parsed", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + invokedModelId: "invalid-arn", + }, + }, + } + + jest.spyOn(handler, "getModel").mockReturnValue({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(handler.getModel()).toEqual({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + }) }) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 76d93649604..1637fe29f3d 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -3,11 +3,19 @@ import { ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig, + ConverseStreamCommandOutput, } from "@aws-sdk/client-bedrock-runtime" import { fromIni } from "@aws-sdk/credential-providers" import { Anthropic } from "@anthropic-ai/sdk" import { SingleCompletionHandler } from "../" -import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" +import { + ApiHandlerOptions, + BedrockModelId, + ModelInfo, + bedrockDefaultModelId, + bedrockModels, + bedrockDefaultPromptRouterModelId, +} from "../../shared/api" import { ApiStream } from "../transform/stream" import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format" import { BaseProvider } from "./base-provider" @@ -21,7 +29,8 @@ import { logger } from "../../utils/logging" */ function validateBedrockArn(arn: string, region?: string) { // Validate ARN format - const arnRegex = /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router)\/(.+)$/ + const arnRegex = + /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router|prompt-router)\/(.+)$/ const match = arn.match(arnRegex) if (!match) { @@ -86,12 +95,27 @@ export interface StreamEvent { latencyMs: number } } + trace?: { + promptRouter?: { + invokedModelId?: string + usage?: { + inputTokens: number + outputTokens: number + totalTokens?: number // Made optional since we don't use it + } + } + } } export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions private client: BedrockRuntimeClient + private costModelConfig: { id: BedrockModelId | string; info: ModelInfo } = { + id: "", + info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false }, + } + constructor(options: ApiHandlerOptions) { super() this.options = options @@ -141,8 +165,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const modelConfig = this.getModel() - + let modelConfig = this.getModel() // Handle cross-region inference let modelId: string @@ -250,8 +273,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH continue } - // Handle metadata events first - if (streamEvent.metadata?.usage) { + // Handle metadata events first. + if (streamEvent?.metadata?.usage) { yield { type: "usage", inputTokens: streamEvent.metadata.usage.inputTokens || 0, @@ -260,6 +283,37 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH continue } + if (streamEvent?.trace?.promptRouter?.invokedModelId) { + try { + const invokedModelId = streamEvent.trace.promptRouter.invokedModelId + const modelMatch = invokedModelId.match(/\/([^\/]+)(?::|$)/) + if (modelMatch && modelMatch[1]) { + let modelName = modelMatch[1] + + // Get a new modelConfig from getModel() using invokedModelId.. remove the region first + let region = modelName.slice(0, 3) + + if (region === "us." || region === "eu.") modelName = modelName.slice(3) + this.costModelConfig = this.getModelByName(modelName) + } + + // Handle metadata events for the promptRouter. + if (streamEvent?.trace?.promptRouter?.usage) { + yield { + type: "usage", + inputTokens: streamEvent?.trace?.promptRouter?.usage?.inputTokens || 0, + outputTokens: streamEvent?.trace?.promptRouter?.usage?.outputTokens || 0, + } + continue + } + } catch (error) { + logger.error("Error handling Bedrock invokedModelId", { + ctx: "bedrock", + error: error instanceof Error ? error : String(error), + }) + } + } + // Handle message start if (streamEvent.messageStart) { continue @@ -282,7 +336,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } continue } - // Handle message stop if (streamEvent.messageStop) { continue @@ -428,122 +481,75 @@ Please check: } } + //Prompt Router responses come back in a different sequence and the yield calls are not resulting in costs getting updated + getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } { + // Try to find the model in bedrockModels + if (modelName in bedrockModels) { + const id = modelName as BedrockModelId + + //Do a deep copy of the model info so that later in the code the model id and maxTokens can be set. + // The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value + // in a prompt router response isn't possible on the constant. + let model = JSON.parse(JSON.stringify(bedrockModels[id])) + + // If modelMaxTokens is explicitly set in options, override the default + if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) { + model.maxTokens = this.options.modelMaxTokens + } + + return { id, info: model } + } + + return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] } + } + override getModel(): { id: BedrockModelId | string; info: ModelInfo } { + if (this.costModelConfig.id.trim().length > 0) { + return this.costModelConfig + } + // If custom ARN is provided, use it if (this.options.awsCustomArn) { - // Custom ARNs should not be modified with region prefixes - // as they already contain the full resource path - - // Check if the ARN contains information about the model type - // This helps set appropriate token limits for models behind prompt routers - const arnLower = this.options.awsCustomArn.toLowerCase() - - // Determine model info based on ARN content - let modelInfo: ModelInfo - - if (arnLower.includes("claude-3-7-sonnet") || arnLower.includes("claude-3.7-sonnet")) { - // Claude 3.7 Sonnet has 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - supportsComputerUse: true, - } - } else if (arnLower.includes("claude-3-5-sonnet") || arnLower.includes("claude-3.5-sonnet")) { - // Claude 3.5 Sonnet has 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - supportsComputerUse: true, - } - } else if (arnLower.includes("claude-3-opus") || arnLower.includes("claude-3.0-opus")) { - // Claude 3 Opus has 4096 tokens in Bedrock - modelInfo = { - maxTokens: 4096, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - } - } else if (arnLower.includes("claude-3-haiku") || arnLower.includes("claude-3.0-haiku")) { - // Claude 3 Haiku has 4096 tokens in Bedrock - modelInfo = { - maxTokens: 4096, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - } - } else if (arnLower.includes("claude-3-5-haiku") || arnLower.includes("claude-3.5-haiku")) { - // Claude 3.5 Haiku has 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: false, - } - } else if (arnLower.includes("claude")) { - // Generic Claude model with conservative token limit - modelInfo = { - maxTokens: 4096, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: true, - } - } else if (arnLower.includes("llama3") || arnLower.includes("llama-3")) { - // Llama 3 models typically have 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: arnLower.includes("90b") || arnLower.includes("11b"), - } - } else if (arnLower.includes("nova-pro")) { - // Amazon Nova Pro - modelInfo = { - maxTokens: 5000, - contextWindow: 300_000, - supportsPromptCache: false, - supportsImages: true, - } - } else { - // Default for unknown models or prompt routers - modelInfo = { - maxTokens: 4096, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: true, + // Extract the model name from the ARN + const arnMatch = this.options.awsCustomArn.match( + /^arn:aws:bedrock:([^:]+):(\d+):(inference-profile|foundation-model|provisioned-model)\/(.+)$/, + ) + + let modelName = arnMatch ? arnMatch[4] : "" + if (modelName) { + let region = modelName.slice(0, 3) + if (region === "us." || region === "eu.") modelName = modelName.slice(3) + + let modelData = this.getModelByName(modelName) + modelData.id = this.options.awsCustomArn + + if (modelData) { + return modelData } } - // If modelMaxTokens is explicitly set in options, override the default - if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) { - modelInfo.maxTokens = this.options.modelMaxTokens - } + // An ARN was used, but no model info match found, use default values based on common patterns + let model = this.getModelByName(bedrockDefaultPromptRouterModelId) + // For custom ARNs, always return the specific values expected by tests return { id: this.options.awsCustomArn, - info: modelInfo, + info: model.info, } } - const modelId = this.options.apiModelId - if (modelId) { + if (this.options.apiModelId) { // Special case for custom ARN option - if (modelId === "custom-arn") { + if (this.options.apiModelId === "custom-arn") { // This should not happen as we should have awsCustomArn set // but just in case, return a default model - return { - id: bedrockDefaultModelId, - info: bedrockModels[bedrockDefaultModelId], - } + return this.getModelByName(bedrockDefaultModelId) } - // For tests, allow any model ID + // For tests, allow any model ID (but not custom ARNs, which are handled above) if (process.env.NODE_ENV === "test") { return { - id: modelId, + id: this.options.apiModelId, info: { maxTokens: 5000, contextWindow: 128_000, @@ -552,15 +558,9 @@ Please check: } } // For production, validate against known models - if (modelId in bedrockModels) { - const id = modelId as BedrockModelId - return { id, info: bedrockModels[id] } - } - } - return { - id: bedrockDefaultModelId, - info: bedrockModels[bedrockDefaultModelId], + return this.getModelByName(this.options.apiModelId) } + return this.getModelByName(bedrockDefaultModelId) } async completePrompt(prompt: string): Promise { @@ -573,10 +573,6 @@ Please check: // For custom ARNs, use the ARN directly without modification if (this.options.awsCustomArn) { modelId = modelConfig.id - logger.debug("Using custom ARN in completePrompt", { - ctx: "bedrock", - customArn: this.options.awsCustomArn, - }) // Validate ARN format and check region match const clientRegion = this.client.config.region as string diff --git a/src/shared/api.ts b/src/shared/api.ts index 8a034d8d6ca..cb0410096c2 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -246,6 +246,8 @@ export interface MessageContent { export type BedrockModelId = keyof typeof bedrockModels export const bedrockDefaultModelId: BedrockModelId = "anthropic.claude-3-7-sonnet-20250219-v1:0" +export const bedrockDefaultPromptRouterModelId: BedrockModelId = "anthropic.claude-3-sonnet-20240229-v1:0" + // March, 12 2025 - updated prices to match US-West-2 list price shown at https://aws.amazon.com/bedrock/pricing/ // including older models that are part of the default prompt routers AWS enabled for GA of the promot router feature export const bedrockModels = {