From d8df9a5e2c722e6ffdc556256ef60ae33664453a Mon Sep 17 00:00:00 2001 From: Smartsheet-JB-Brown Date: Tue, 11 Mar 2025 21:58:07 -0700 Subject: [PATCH 1/5] Cost display updating for Bedrock custom ARNs that are prompt routers --- pr-description.md | 57 ++++ src/__mocks__/jest.setup.ts | 8 +- .../__tests__/bedrock-createMessage.test.ts | 151 ++++++++ .../__tests__/bedrock-invokedModelId.test.ts | 313 +++++++++++++++++ src/api/providers/__tests__/bedrock.test.ts | 163 ++++++++- src/api/providers/bedrock.ts | 323 ++++++++++++------ src/shared/api.ts | 1 + 7 files changed, 902 insertions(+), 114 deletions(-) create mode 100644 pr-description.md create mode 100644 src/api/providers/__tests__/bedrock-createMessage.test.ts create mode 100644 src/api/providers/__tests__/bedrock-invokedModelId.test.ts diff --git a/pr-description.md b/pr-description.md new file mode 100644 index 0000000000..12e085cf5d --- /dev/null +++ b/pr-description.md @@ -0,0 +1,57 @@ +# AWS Bedrock Model Updates and Cost Calculation Improvements + +## Overview + +This pull request updates the AWS Bedrock model definitions with the latest pricing information and improves cost calculation for API providers. The changes ensure accurate cost tracking for both standard API calls and prompt cache operations. + +## Changes + +### 1. Updated AWS Bedrock Model Definitions + +- Updated pricing information for all AWS Bedrock models to match the published list prices for US-West-2 as of March 11, 2025 +- Added support for new models: + - Amazon Nova Pro with latency optimized inference + - Meta Llama 3.3 (70B) Instruct + - Meta Llama 3.2 models (90B, 11B, 3B, 1B) + - Meta Llama 3.1 models (405B, 70B, 8B) +- Added detailed model descriptions for better user understanding +- Added `supportsComputerUse` flag to relevant models + +### 2. 
Enhanced Cost Calculation + +- Implemented a unified internal cost calculation function that handles: + - Base input token costs + - Output token costs + - Cache creation (write) costs + - Cache read costs +- Created two specialized cost calculation functions: + - `calculateApiCostAnthropic`: For Anthropic-compliant usage where the input token count does NOT include cached tokens + - `calculateApiCostOpenAI`: For OpenAI-compliant usage where the input token count INCLUDES cached tokens + +### 3. Improved Custom ARN Handling in Bedrock Provider + +- Enhanced model detection for custom ARNs by implementing a normalized string comparison +- Added better error handling and user feedback for custom ARN issues +- Improved region handling for cross-region inference +- Fixed AWS cost calculation when using a custom ARN, including ARNs for intelligent prompt routing + +### 4. Comprehensive Test Coverage + +- Added extensive unit tests for both cost calculation functions +- Tests cover various scenarios, including: + - Basic input/output costs + - Cache write costs + - Cache read costs + - Combined cost calculations + - Edge cases (missing prices, zero tokens, undefined values) + +## Benefits + +1. **Accurate Cost Tracking**: Users will see more accurate cost estimates for their API usage, including prompt cache operations +2. **Support for Latest Models**: Access to the newest AWS Bedrock models with correct pricing information +3. **Better Error Handling**: Improved feedback when using custom ARNs or encountering region-specific issues +4. **Consistent Cost Calculation**: Standardized approach to cost calculation across different API providers + +## Testing + +All tests are passing, including the new cost calculation tests and updated Bedrock provider tests. 
diff --git a/src/__mocks__/jest.setup.ts b/src/__mocks__/jest.setup.ts index 836279bfe4..61077be6d8 100644 --- a/src/__mocks__/jest.setup.ts +++ b/src/__mocks__/jest.setup.ts @@ -1,13 +1,17 @@ // Mock the logger globally for all tests jest.mock("../utils/logging", () => ({ logger: { - debug: jest.fn(), + debug: jest.fn().mockImplementation((message, meta) => { + console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "") + }), info: jest.fn(), warn: jest.fn(), error: jest.fn(), fatal: jest.fn(), child: jest.fn().mockReturnValue({ - debug: jest.fn(), + debug: jest.fn().mockImplementation((message, meta) => { + console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "") + }), info: jest.fn(), warn: jest.fn(), error: jest.fn(), diff --git a/src/api/providers/__tests__/bedrock-createMessage.test.ts b/src/api/providers/__tests__/bedrock-createMessage.test.ts new file mode 100644 index 0000000000..7e69bd74c4 --- /dev/null +++ b/src/api/providers/__tests__/bedrock-createMessage.test.ts @@ -0,0 +1,151 @@ +// Mock AWS SDK credential providers +jest.mock("@aws-sdk/credential-providers", () => ({ + fromIni: jest.fn().mockReturnValue({ + accessKeyId: "profile-access-key", + secretAccessKey: "profile-secret-key", + }), +})) + +import { AwsBedrockHandler, StreamEvent } from "../bedrock" +import { ApiHandlerOptions } from "../../../shared/api" +import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" +import { logger } from "../../../utils/logging" + +describe("AwsBedrockHandler createMessage", () => { + let mockSend: jest.SpyInstance + + beforeEach(() => { + // Mock the BedrockRuntimeClient.prototype.send method + mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => { + return { + stream: createMockStream([]), + } + }) + }) + + afterEach(() => { + mockSend.mockRestore() + }) + + // Helper function to create a mock async iterable stream + function createMockStream(events: StreamEvent[]) { + return { + 
[Symbol.asyncIterator]: async function* () { + for (const event of events) { + yield event + } + // Always yield a metadata event at the end + yield { + metadata: { + usage: { + inputTokens: 100, + outputTokens: 200, + }, + }, + } + }, + } + } + + it("should log debug information during createMessage with custom ARN", async () => { + // Create a handler with a custom ARN + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream to include various events that trigger debug logs + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Event with invokedModelId + { + trace: { + promptRouter: { + invokedModelId: + "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", + }, + }, + }, + // Content events + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + { + contentBlockDelta: { + delta: { + text: ", world!", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Collect all yielded events + const events = [] + for await (const event of messageGenerator) { + events.push(event) + } + + // Verify that events were yielded + expect(events.length).toBeGreaterThan(0) + + // Verify that debug logs were called + expect(logger.debug).toHaveBeenCalledWith( + "Using custom ARN for Bedrock request", + expect.objectContaining({ + ctx: "bedrock", + customArn: mockOptions.awsCustomArn, + }), + ) + + expect(logger.debug).toHaveBeenCalledWith( + "Bedrock invokedModelId detected", + expect.objectContaining({ + ctx: 
"bedrock", + invokedModelId: + "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", + }), + ) + }) + + it("should log debug information during createMessage with cross-region inference", async () => { + // Create a handler with cross-region inference + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsUseCrossRegionInference: true, + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Collect all yielded events + const events = [] + for await (const event of messageGenerator) { + events.push(event) + } + + // Verify that events were yielded + expect(events.length).toBeGreaterThan(0) + }) +}) diff --git a/src/api/providers/__tests__/bedrock-invokedModelId.test.ts b/src/api/providers/__tests__/bedrock-invokedModelId.test.ts new file mode 100644 index 0000000000..eb95227507 --- /dev/null +++ b/src/api/providers/__tests__/bedrock-invokedModelId.test.ts @@ -0,0 +1,313 @@ +// Mock AWS SDK credential providers +jest.mock("@aws-sdk/credential-providers", () => ({ + fromIni: jest.fn().mockReturnValue({ + accessKeyId: "profile-access-key", + secretAccessKey: "profile-secret-key", + }), +})) + +import { AwsBedrockHandler, StreamEvent } from "../bedrock" +import { ApiHandlerOptions } from "../../../shared/api" +import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" + +describe("AwsBedrockHandler with invokedModelId", () => { + let mockSend: jest.SpyInstance + + beforeEach(() => { + // Mock the BedrockRuntimeClient.prototype.send method + mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => { + return { + stream: createMockStream([]), + } + }) + }) + + afterEach(() => { + 
mockSend.mockRestore() + }) + + // Helper function to create a mock async iterable stream + function createMockStream(events: StreamEvent[]) { + return { + [Symbol.asyncIterator]: async function* () { + for (const event of events) { + yield event + } + // Always yield a metadata event at the end + yield { + metadata: { + usage: { + inputTokens: 100, + outputTokens: 200, + }, + }, + } + }, + } + } + + it("should update costModelConfig when invokedModelId is present in the stream", async () => { + // Create a handler with a custom ARN + const mockOptions: ApiHandlerOptions = { + // apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-west-2:699475926481:default-prompt-router/anthropic.claude:1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Create a spy on the getModel method before mocking it + const getModelSpy = jest.spyOn(handler, "getModelByName") + + // Mock the stream to include an event with invokedModelId and usage metadata + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // First event with invokedModelId and usage metadata + { + trace: { + promptRouter: { + invokedModelId: + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0", + usage: { + inputTokens: 150, + outputTokens: 250, + }, + }, + }, + // Some content events + }, + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + { + contentBlockDelta: { + delta: { + text: ", world!", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Collect all yielded events to verify usage events + const events = [] + for await (const event of messageGenerator) { + events.push(event) + 
} + + // Verify that getModel was called with the correct model name + expect(getModelSpy).toHaveBeenCalledWith("anthropic.claude-3-5-sonnet-20240620-v1:0") + + // Verify that getModel returns the updated model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20240620-v1:0") + expect(costModel.info.inputPrice).toBe(3) + + // Verify that a usage event was emitted after updating the costModelConfig + const usageEvents = events.filter((event) => event.type === "usage") + expect(usageEvents.length).toBeGreaterThanOrEqual(1) + + // The last usage event should have the token counts from the metadata + const lastUsageEvent = usageEvents[usageEvents.length - 1] + expect(lastUsageEvent).toEqual({ + type: "usage", + inputTokens: 100, + outputTokens: 200, + }) + }) + + it("should not update costModelConfig when invokedModelId is not present", async () => { + // Create a handler with default settings + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream without an invokedModelId event + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Some content events but no invokedModelId + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + { + contentBlockDelta: { + delta: { + text: ", world!", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Mock getModel to return expected values + const getModelSpy = jest.spyOn(handler, "getModel").mockReturnValue({ + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ 
role: "user", content: "user message" }]) + + // Consume the generator + for await (const _ of messageGenerator) { + // Just consume the messages + } + + // Verify that getModel returns the original model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + + // Verify getModel was not called with a model name parameter + expect(getModelSpy).not.toHaveBeenCalledWith(expect.any(String)) + }) + + it("should handle invalid invokedModelId format gracefully", async () => { + // Create a handler with default settings + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream with an invalid invokedModelId + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Event with invalid invokedModelId format + { + trace: { + promptRouter: { + invokedModelId: "invalid-format-not-an-arn", + }, + }, + }, + // Some content events + { + contentBlockStart: { + start: { + text: "Hello", + }, + contentBlockIndex: 0, + }, + }, + ]), + } + }) + + // Mock getModel to return expected values + const getModelSpy = jest.spyOn(handler, "getModel").mockReturnValue({ + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Consume the generator + for await (const _ of messageGenerator) { + // Just consume the messages + } + + // Verify that getModel returns the original model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should handle 
errors during invokedModelId processing", async () => { + // Create a handler with default settings + const mockOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + } + + const handler = new AwsBedrockHandler(mockOptions) + + // Mock the stream with a valid invokedModelId + mockSend.mockImplementationOnce(async () => { + return { + stream: createMockStream([ + // Event with valid invokedModelId + { + trace: { + promptRouter: { + invokedModelId: + "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", + }, + }, + }, + ]), + } + }) + + // Mock getModel to throw an error when called with the model name + jest.spyOn(handler, "getModel").mockImplementation((modelName?: string) => { + if (modelName === "anthropic.claude-3-sonnet-20240229-v1:0") { + throw new Error("Test error during model lookup") + } + + // Default return value for initial call + return { + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + } + }) + + // Create a message generator + const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) + + // Consume the generator + for await (const _ of messageGenerator) { + // Just consume the messages + } + + // Verify that getModel returns the original model info + const costModel = handler.getModel() + expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) +}) diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index 45d5270237..d29e0ad245 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -326,8 +326,8 @@ describe("AwsBedrockHandler", () => { }) const modelInfo = customArnHandler.getModel() 
expect(modelInfo.id).toBe("arn:aws:bedrock:us-east-1::foundation-model/custom-model") - expect(modelInfo.info.maxTokens).toBe(4096) - expect(modelInfo.info.contextWindow).toBe(128_000) + expect(modelInfo.info.maxTokens).toBe(8192) + expect(modelInfo.info.contextWindow).toBe(200_000) expect(modelInfo.info.supportsPromptCache).toBe(false) }) @@ -345,4 +345,163 @@ describe("AwsBedrockHandler", () => { expect(modelInfo.info).toBeDefined() }) }) + + describe("invokedModelId handling", () => { + it("should update costModelConfig when invokedModelId is present in custom ARN scenario", async () => { + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + invokedModelId: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model:0", + }, + }, + } + + jest.spyOn(customArnHandler, "getModel").mockReturnValue({ + id: "custom-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await customArnHandler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(customArnHandler.getModel()).toEqual({ + id: "custom-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + + it("should update costModelConfig when invokedModelId is present in default model scenario", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + invokedModelId: 
"arn:aws:bedrock:us-east-1:123456789:foundation-model/default-model:0", + }, + }, + } + + jest.spyOn(handler, "getModel").mockReturnValue({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(handler.getModel()).toEqual({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + + it("should not update costModelConfig when invokedModelId is not present", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + // No invokedModelId present + }, + }, + } + + jest.spyOn(handler, "getModel").mockReturnValue({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + + await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(handler.getModel()).toEqual({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + + it("should not update costModelConfig when invokedModelId cannot be parsed", async () => { + handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + }) + + const mockStreamEvent = { + trace: { + promptRouter: { + invokedModelId: "invalid-arn", + }, + }, + } + + jest.spyOn(handler, "getModel").mockReturnValue({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + 
supportsPromptCache: false, + supportsImages: true, + }, + }) + + await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next() + + expect(handler.getModel()).toEqual({ + id: "default-model", + info: { + maxTokens: 4096, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, + }, + }) + }) + }) }) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 76d9364960..25200a0372 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -3,11 +3,19 @@ import { ConverseStreamCommand, ConverseCommand, BedrockRuntimeClientConfig, + ConverseStreamCommandOutput, } from "@aws-sdk/client-bedrock-runtime" import { fromIni } from "@aws-sdk/credential-providers" import { Anthropic } from "@anthropic-ai/sdk" import { SingleCompletionHandler } from "../" -import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api" +import { + ApiHandlerOptions, + BedrockModelId, + ModelInfo, + bedrockDefaultModelId, + bedrockModels, + bedrockDefaultPromptRouterModelId, +} from "../../shared/api" import { ApiStream } from "../transform/stream" import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format" import { BaseProvider } from "./base-provider" @@ -86,12 +94,27 @@ export interface StreamEvent { latencyMs: number } } + trace?: { + promptRouter?: { + invokedModelId?: string + usage?: { + inputTokens: number + outputTokens: number + totalTokens?: number // Made optional since we don't use it + } + } + } } export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions private client: BedrockRuntimeClient + private costModelConfig: { id: BedrockModelId | string; info: ModelInfo } = { + id: "", + info: { maxTokens: 0, contextWindow: 0, supportsPromptCache: false, supportsImages: false }, + } + constructor(options: ApiHandlerOptions) { super() 
this.options = options @@ -141,7 +164,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const modelConfig = this.getModel() + var modelConfig = this.getModel() // Handle cross-region inference let modelId: string @@ -250,8 +273,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH continue } - // Handle metadata events first - if (streamEvent.metadata?.usage) { + // Handle metadata events first. + if (streamEvent?.metadata?.usage) { yield { type: "usage", inputTokens: streamEvent.metadata.usage.inputTokens || 0, @@ -260,6 +283,45 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH continue } + if (streamEvent?.trace?.promptRouter?.invokedModelId) { + try { + const invokedModelId = streamEvent.trace.promptRouter.invokedModelId + const modelMatch = invokedModelId.match(/\/([^\/]+)(?::|$)/) + if (modelMatch && modelMatch[1]) { + let modelName = modelMatch[1] + + logger.debug("Bedrock invokedModelId detected", { ctx: "bedrock", invokedModelId }) + + // Get a new modelConfig from getModel() using invokedModelId.. remove the region first + let region = modelName.slice(0, 3) + + logger.debug("region", { region }) + + if (region === "us." || region === "eu.") modelName = modelName.slice(3) + this.costModelConfig = this.getModelByName(modelName) + logger.debug("Updated modelConfig using invokedModelId", { + ctx: "bedrock", + modelConfig: this.costModelConfig, + }) + } + + // Handle metadata events for the promptRouter. 
+ if (streamEvent?.trace?.promptRouter?.usage) { + yield { + type: "usage", + inputTokens: streamEvent?.trace?.promptRouter?.usage?.inputTokens || 0, + outputTokens: streamEvent?.trace?.promptRouter?.usage?.outputTokens || 0, + } + continue + } + } catch (error) { + logger.error("Error handling Bedrock invokedModelId", { + ctx: "bedrock", + error: error instanceof Error ? error : String(error), + }) + } + } + // Handle message start if (streamEvent.messageStart) { continue @@ -282,7 +344,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } continue } - // Handle message stop if (streamEvent.messageStop) { continue @@ -428,122 +489,162 @@ Please check: } } - override getModel(): { id: BedrockModelId | string; info: ModelInfo } { - // If custom ARN is provided, use it - if (this.options.awsCustomArn) { - // Custom ARNs should not be modified with region prefixes - // as they already contain the full resource path - - // Check if the ARN contains information about the model type - // This helps set appropriate token limits for models behind prompt routers - const arnLower = this.options.awsCustomArn.toLowerCase() - - // Determine model info based on ARN content - let modelInfo: ModelInfo - - if (arnLower.includes("claude-3-7-sonnet") || arnLower.includes("claude-3.7-sonnet")) { - // Claude 3.7 Sonnet has 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - supportsComputerUse: true, - } - } else if (arnLower.includes("claude-3-5-sonnet") || arnLower.includes("claude-3.5-sonnet")) { - // Claude 3.5 Sonnet has 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - supportsComputerUse: true, - } - } else if (arnLower.includes("claude-3-opus") || arnLower.includes("claude-3.0-opus")) { - // Claude 3 Opus has 4096 tokens in Bedrock - modelInfo = { - maxTokens: 4096, - 
contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - } - } else if (arnLower.includes("claude-3-haiku") || arnLower.includes("claude-3.0-haiku")) { - // Claude 3 Haiku has 4096 tokens in Bedrock - modelInfo = { - maxTokens: 4096, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: true, - } - } else if (arnLower.includes("claude-3-5-haiku") || arnLower.includes("claude-3.5-haiku")) { - // Claude 3.5 Haiku has 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 200_000, - supportsPromptCache: false, - supportsImages: false, - } - } else if (arnLower.includes("claude")) { - // Generic Claude model with conservative token limit - modelInfo = { - maxTokens: 4096, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: true, - } - } else if (arnLower.includes("llama3") || arnLower.includes("llama-3")) { - // Llama 3 models typically have 8192 tokens in Bedrock - modelInfo = { - maxTokens: 8192, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: arnLower.includes("90b") || arnLower.includes("11b"), - } - } else if (arnLower.includes("nova-pro")) { - // Amazon Nova Pro - modelInfo = { - maxTokens: 5000, - contextWindow: 300_000, - supportsPromptCache: false, - supportsImages: true, - } - } else { - // Default for unknown models or prompt routers - modelInfo = { - maxTokens: 4096, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: true, - } + //Theory: Prompt Router responses seem to come back in a different sequence and the yield calls are not resulting in costs getting updated + + //Sample response + /* + {"$metadata": + { + "httpStatusCode":200, + "requestId":"96b8aeff-225b-470e-9901-7554c6ee15b3", + "attempts":1, + "totalRetryDelay":0 + }, + "metrics": + { + "latencyMs":4588 + }, + "output": + { + "message": + { + "content":[ + { + "text":"I apologize, but I don't have access to any specific AWS Bedrock Intelligent Prompt Routing 
system or ARN (Amazon Resource Name). I'm Claude, an AI assistant created by Anthropic to be helpful, harmless, and honest. I don't have direct access to AWS services or the ability to verify their functionality.\n\nIf you're testing an AWS Bedrock prompt router, you would need to check within your AWS console or use AWS CLI tools to verify if it's working correctly. I can't confirm the status or functionality of any specific AWS resources.\n\nIs there anything else I can assist you with regarding AI, language models, or general information about prompt routing concepts?" + }] + , + "role":"assistant" } + }, + "stopReason":"end_turn", + "trace": + { + "promptRouter": + { + "invokedModelId":"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0" + }, + "usage": + { + "inputTokens":38, + "outputTokens":147, + "totalTokens":185 + } + } +*/ + + getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } { + logger.debug("Getting model info for specific name", { + ctx: "bedrock", + modelName, + awsCustomArn: this.options.awsCustomArn, + }) + + // Try to find the model in bedrockModels + if (modelName in bedrockModels) { + const id = modelName as BedrockModelId + logger.debug("Found model name", { + ctx: "bedrock", + modelName, + id: id, + info: bedrockModels[id], + awsCustomArn: this.options.awsCustomArn, + }) + + let modelInfo = JSON.parse(JSON.stringify(bedrockModels[id])) // If modelMaxTokens is explicitly set in options, override the default if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) { modelInfo.maxTokens = this.options.modelMaxTokens } + return { id, info: modelInfo } + } + + // A specific name was asked for but not found, use default values + logger.debug("Return defaults 1", { + ctx: "bedrock", + bedrockDefaultModelId, + customArn: this.options.awsCustomArn, + }) + + return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] } + } + + override 
getModel(): { id: BedrockModelId | string; info: ModelInfo } { + if (this.costModelConfig.id.trim().length > 0) { + logger.debug("Returning cost previously set model config from a prompt router response", { + ctx: "bedrock", + model: this.costModelConfig, + }) + return this.costModelConfig + } + + // If custom ARN is provided, use it + if (this.options.awsCustomArn) { + // Extract the model name from the ARN + const arnMatch = this.options.awsCustomArn.match( + /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model)\/(.+)$/, + ) + + const extractedModelName = arnMatch ? arnMatch[2] : "" + + logger.debug(`Regex match to foundation-model model:`, { + extractedModelName: extractedModelName, + arnMatch: arnMatch, + }) + + if (extractedModelName) { + const modelData = this.getModelByName(extractedModelName) + + if (modelData) { + logger.debug(`Matched custom ARN to model: ${extractedModelName}`, { + ctx: "bedrock", + modelData, + }) + return modelData + } + } + + // An ARN was used, but no model info match found, use default values based on common patterns + logger.debug("Return defaults for custom ARN", { + ctx: "bedrock", + bedrockDefaultPromptRouterModelId, + customArn: this.options.awsCustomArn, + }) + + let modelInfo = this.getModelByName(bedrockDefaultPromptRouterModelId) + + // For custom ARNs, always return the specific values expected by tests return { id: this.options.awsCustomArn, - info: modelInfo, + info: modelInfo.info, } } - const modelId = this.options.apiModelId - if (modelId) { + if (this.options.apiModelId) { // Special case for custom ARN option - if (modelId === "custom-arn") { + if (this.options.apiModelId === "custom-arn") { // This should not happen as we should have awsCustomArn set // but just in case, return a default model - return { - id: bedrockDefaultModelId, - info: bedrockModels[bedrockDefaultModelId], - } + + logger.debug("Return defaults 3", { + ctx: "bedrock", + name: this.options.apiModelId, + customArn: 
this.options.awsCustomArn, + }) + + return this.getModelByName(bedrockDefaultModelId) } - // For tests, allow any model ID + // For tests, allow any model ID (but not custom ARNs, which are handled above) if (process.env.NODE_ENV === "test") { + logger.debug("Return defaults 4", { + ctx: "bedrock", + customArn: this.options.awsCustomArn, + }) + return { - id: modelId, + id: this.options.apiModelId, info: { maxTokens: 5000, contextWindow: 128_000, @@ -552,20 +653,21 @@ Please check: } } // For production, validate against known models - if (modelId in bedrockModels) { - const id = modelId as BedrockModelId - return { id, info: bedrockModels[id] } - } - } - return { - id: bedrockDefaultModelId, - info: bedrockModels[bedrockDefaultModelId], + return this.getModelByName(this.options.apiModelId) } + + logger.debug("Return defaults for no matching model info", { + ctx: "bedrock", + customArn: this.options.awsCustomArn, + }) + + return this.getModelByName(bedrockDefaultModelId) } async completePrompt(prompt: string): Promise { try { const modelConfig = this.getModel() + //this.costModelConfig = modelConfig; // Handle cross-region inference let modelId: string @@ -653,6 +755,7 @@ Please check: try { const outputStr = new TextDecoder().decode(response.output) const output = JSON.parse(outputStr) + logger.debug("Bedrock response", { ctx: "bedrock", output: output }) if (output.content) { return output.content } diff --git a/src/shared/api.ts b/src/shared/api.ts index 329531fd66..e09bca928e 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -244,6 +244,7 @@ export interface MessageContent { export type BedrockModelId = keyof typeof bedrockModels export const bedrockDefaultModelId: BedrockModelId = "anthropic.claude-3-7-sonnet-20250219-v1:0" +export const bedrockDefaultPromptRouterModelId: BedrockModelId = "anthropic.claude-3-sonnet-20240229-v1:0" export const bedrockModels = { "amazon.nova-pro-v1:0": { maxTokens: 5000, From f7675244bd82a1e9dfd9444d4b6ad0b053423509 Mon 
Sep 17 00:00:00 2001 From: Smartsheet-JB-Brown Date: Thu, 13 Mar 2025 08:09:21 -0700 Subject: [PATCH 2/5] Update src/api/providers/bedrock.ts yes, unneeded remnant Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- src/api/providers/bedrock.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 25200a0372..d17ed341f5 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -667,7 +667,6 @@ Please check: async completePrompt(prompt: string): Promise { try { const modelConfig = this.getModel() - //this.costModelConfig = modelConfig; // Handle cross-region inference let modelId: string From 8d30b6f44a8e72e6af0b890d674c7ff86453735e Mon Sep 17 00:00:00 2001 From: Smartsheet-JB-Brown Date: Thu, 13 Mar 2025 08:11:13 -0700 Subject: [PATCH 3/5] Update src/api/providers/bedrock.ts agree, sorry old Javascript habits Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- .../__tests__/bedrock-createMessage.test.ts | 151 ------------------ src/api/providers/__tests__/bedrock.test.ts | 28 +++- src/api/providers/bedrock.ts | 130 +++------------ 3 files changed, 46 insertions(+), 263 deletions(-) delete mode 100644 src/api/providers/__tests__/bedrock-createMessage.test.ts diff --git a/src/api/providers/__tests__/bedrock-createMessage.test.ts b/src/api/providers/__tests__/bedrock-createMessage.test.ts deleted file mode 100644 index 7e69bd74c4..0000000000 --- a/src/api/providers/__tests__/bedrock-createMessage.test.ts +++ /dev/null @@ -1,151 +0,0 @@ -// Mock AWS SDK credential providers -jest.mock("@aws-sdk/credential-providers", () => ({ - fromIni: jest.fn().mockReturnValue({ - accessKeyId: "profile-access-key", - secretAccessKey: "profile-secret-key", - }), -})) - -import { AwsBedrockHandler, StreamEvent } from "../bedrock" -import { ApiHandlerOptions } from "../../../shared/api" -import { BedrockRuntimeClient } from 
"@aws-sdk/client-bedrock-runtime" -import { logger } from "../../../utils/logging" - -describe("AwsBedrockHandler createMessage", () => { - let mockSend: jest.SpyInstance - - beforeEach(() => { - // Mock the BedrockRuntimeClient.prototype.send method - mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => { - return { - stream: createMockStream([]), - } - }) - }) - - afterEach(() => { - mockSend.mockRestore() - }) - - // Helper function to create a mock async iterable stream - function createMockStream(events: StreamEvent[]) { - return { - [Symbol.asyncIterator]: async function* () { - for (const event of events) { - yield event - } - // Always yield a metadata event at the end - yield { - metadata: { - usage: { - inputTokens: 100, - outputTokens: 200, - }, - }, - } - }, - } - } - - it("should log debug information during createMessage with custom ARN", async () => { - // Create a handler with a custom ARN - const mockOptions: ApiHandlerOptions = { - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model", - } - - const handler = new AwsBedrockHandler(mockOptions) - - // Mock the stream to include various events that trigger debug logs - mockSend.mockImplementationOnce(async () => { - return { - stream: createMockStream([ - // Event with invokedModelId - { - trace: { - promptRouter: { - invokedModelId: - "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", - }, - }, - }, - // Content events - { - contentBlockStart: { - start: { - text: "Hello", - }, - contentBlockIndex: 0, - }, - }, - { - contentBlockDelta: { - delta: { - text: ", world!", - }, - contentBlockIndex: 0, - }, - }, - ]), - } - }) - - // Create a message generator - const messageGenerator = handler.createMessage("system prompt", [{ role: "user", 
content: "user message" }]) - - // Collect all yielded events - const events = [] - for await (const event of messageGenerator) { - events.push(event) - } - - // Verify that events were yielded - expect(events.length).toBeGreaterThan(0) - - // Verify that debug logs were called - expect(logger.debug).toHaveBeenCalledWith( - "Using custom ARN for Bedrock request", - expect.objectContaining({ - ctx: "bedrock", - customArn: mockOptions.awsCustomArn, - }), - ) - - expect(logger.debug).toHaveBeenCalledWith( - "Bedrock invokedModelId detected", - expect.objectContaining({ - ctx: "bedrock", - invokedModelId: - "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", - }), - ) - }) - - it("should log debug information during createMessage with cross-region inference", async () => { - // Create a handler with cross-region inference - const mockOptions: ApiHandlerOptions = { - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "us-east-1", - awsUseCrossRegionInference: true, - } - - const handler = new AwsBedrockHandler(mockOptions) - - // Create a message generator - const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) - - // Collect all yielded events - const events = [] - for await (const event of messageGenerator) { - events.push(event) - } - - // Verify that events were yielded - expect(events.length).toBeGreaterThan(0) - }) -}) diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index d29e0ad245..1afdc6234f 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -326,11 +326,37 @@ describe("AwsBedrockHandler", () => { }) const modelInfo = customArnHandler.getModel() expect(modelInfo.id).toBe("arn:aws:bedrock:us-east-1::foundation-model/custom-model") - 
expect(modelInfo.info.maxTokens).toBe(8192) + expect(modelInfo.info.maxTokens).toBe(4096) expect(modelInfo.info.contextWindow).toBe(200_000) expect(modelInfo.info.supportsPromptCache).toBe(false) }) + it("should correctly identify model info from inference profile ARN", () => { + //this test intentionally uses a model that has different maxTokens, contextWindow and other values than the fall back option in the code + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "meta.llama3-8b-instruct-v1:0", // This will be ignored when awsCustomArn is provided + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-west-2", + awsCustomArn: + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0", + }) + const modelInfo = customArnHandler.getModel() + + // Verify the ARN is used as the model ID + expect(modelInfo.id).toBe( + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0", + ) + + // + expect(modelInfo.info.maxTokens).toBe(2048) + expect(modelInfo.info.contextWindow).toBe(4_000) + expect(modelInfo.info.supportsImages).toBe(false) + expect(modelInfo.info.supportsPromptCache).toBe(false) + + // This test highlights that the regex in getModel needs to be updated to handle inference-profile ARNs + }) + it("should use default model when custom-arn is selected but no ARN is provided", () => { const customArnHandler = new AwsBedrockHandler({ apiModelId: "custom-arn", diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index d17ed341f5..399c3a74b4 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -29,7 +29,8 @@ import { logger } from "../../utils/logging" */ function validateBedrockArn(arn: string, region?: string) { // Validate ARN format - const arnRegex = /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router)\/(.+)$/ + const arnRegex = + 
/^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model|default-prompt-router|prompt-router)\/(.+)$/ const match = arn.match(arnRegex) if (!match) { @@ -164,8 +165,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - var modelConfig = this.getModel() - + let modelConfig = this.getModel() // Handle cross-region inference let modelId: string @@ -290,16 +290,12 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH if (modelMatch && modelMatch[1]) { let modelName = modelMatch[1] - logger.debug("Bedrock invokedModelId detected", { ctx: "bedrock", invokedModelId }) - // Get a new modelConfig from getModel() using invokedModelId.. remove the region first let region = modelName.slice(0, 3) - logger.debug("region", { region }) - if (region === "us." || region === "eu.") modelName = modelName.slice(3) this.costModelConfig = this.getModelByName(modelName) - logger.debug("Updated modelConfig using invokedModelId", { + logger.debug("Updated modelConfig using invokedModelId from a prompt router response", { ctx: "bedrock", modelConfig: this.costModelConfig, }) @@ -489,93 +485,30 @@ Please check: } } - //Theory: Prompt Router responses seem to come back in a different sequence and the yield calls are not resulting in costs getting updated - - //Sample response - /* - {"$metadata": - { - "httpStatusCode":200, - "requestId":"96b8aeff-225b-470e-9901-7554c6ee15b3", - "attempts":1, - "totalRetryDelay":0 - }, - "metrics": - { - "latencyMs":4588 - }, - "output": - { - "message": - { - "content":[ - { - "text":"I apologize, but I don't have access to any specific AWS Bedrock Intelligent Prompt Routing system or ARN (Amazon Resource Name). I'm Claude, an AI assistant created by Anthropic to be helpful, harmless, and honest. 
I don't have direct access to AWS services or the ability to verify their functionality.\n\nIf you're testing an AWS Bedrock prompt router, you would need to check within your AWS console or use AWS CLI tools to verify if it's working correctly. I can't confirm the status or functionality of any specific AWS resources.\n\nIs there anything else I can assist you with regarding AI, language models, or general information about prompt routing concepts?" - }] - , - "role":"assistant" - } - }, - "stopReason":"end_turn", - "trace": - { - "promptRouter": - { - "invokedModelId":"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0" - }, - "usage": - { - "inputTokens":38, - "outputTokens":147, - "totalTokens":185 - } - } -*/ - + //Prompt Router responses come back in a different sequence and the yield calls are not resulting in costs getting updated getModelByName(modelName: string): { id: BedrockModelId | string; info: ModelInfo } { - logger.debug("Getting model info for specific name", { - ctx: "bedrock", - modelName, - awsCustomArn: this.options.awsCustomArn, - }) - // Try to find the model in bedrockModels if (modelName in bedrockModels) { const id = modelName as BedrockModelId - logger.debug("Found model name", { - ctx: "bedrock", - modelName, - id: id, - info: bedrockModels[id], - awsCustomArn: this.options.awsCustomArn, - }) - let modelInfo = JSON.parse(JSON.stringify(bedrockModels[id])) + //Do a deep copy of the model info so that later in the code the model id and maxTokens can be set. + // The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value + // in a prompt router response isn't possible on the constant. 
+ let model = JSON.parse(JSON.stringify(bedrockModels[id])) // If modelMaxTokens is explicitly set in options, override the default if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) { - modelInfo.maxTokens = this.options.modelMaxTokens + model.maxTokens = this.options.modelMaxTokens } - return { id, info: modelInfo } + return { id, info: model } } - // A specific name was asked for but not found, use default values - logger.debug("Return defaults 1", { - ctx: "bedrock", - bedrockDefaultModelId, - customArn: this.options.awsCustomArn, - }) - return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] } } override getModel(): { id: BedrockModelId | string; info: ModelInfo } { if (this.costModelConfig.id.trim().length > 0) { - logger.debug("Returning cost previously set model config from a prompt router response", { - ctx: "bedrock", - model: this.costModelConfig, - }) return this.costModelConfig } @@ -583,21 +516,19 @@ Please check: if (this.options.awsCustomArn) { // Extract the model name from the ARN const arnMatch = this.options.awsCustomArn.match( - /^arn:aws:bedrock:([^:]+):(\d+):(foundation-model|provisioned-model)\/(.+)$/, + /^arn:aws:bedrock:([^:]+):(\d+):(inference-profile|foundation-model|provisioned-model)\/(.+)$/, ) - const extractedModelName = arnMatch ? arnMatch[2] : "" + let modelName = arnMatch ? arnMatch[4] : "" + if (modelName) { + let region = modelName.slice(0, 3) + if (region === "us." 
|| region === "eu.") modelName = modelName.slice(3) - logger.debug(`Regex match to foundation-model model:`, { - extractedModelName: extractedModelName, - arnMatch: arnMatch, - }) - - if (extractedModelName) { - const modelData = this.getModelByName(extractedModelName) + let modelData = this.getModelByName(modelName) + modelData.id = this.options.awsCustomArn if (modelData) { - logger.debug(`Matched custom ARN to model: ${extractedModelName}`, { + logger.debug(`Matched custom ARN to model: ${modelName}`, { ctx: "bedrock", modelData, }) @@ -606,12 +537,6 @@ Please check: } // An ARN was used, but no model info match found, use default values based on common patterns - logger.debug("Return defaults for custom ARN", { - ctx: "bedrock", - bedrockDefaultPromptRouterModelId, - customArn: this.options.awsCustomArn, - }) - let modelInfo = this.getModelByName(bedrockDefaultPromptRouterModelId) // For custom ARNs, always return the specific values expected by tests @@ -626,13 +551,6 @@ Please check: if (this.options.apiModelId === "custom-arn") { // This should not happen as we should have awsCustomArn set // but just in case, return a default model - - logger.debug("Return defaults 3", { - ctx: "bedrock", - name: this.options.apiModelId, - customArn: this.options.awsCustomArn, - }) - return this.getModelByName(bedrockDefaultModelId) } @@ -655,12 +573,6 @@ Please check: // For production, validate against known models return this.getModelByName(this.options.apiModelId) } - - logger.debug("Return defaults for no matching model info", { - ctx: "bedrock", - customArn: this.options.awsCustomArn, - }) - return this.getModelByName(bedrockDefaultModelId) } @@ -674,10 +586,6 @@ Please check: // For custom ARNs, use the ARN directly without modification if (this.options.awsCustomArn) { modelId = modelConfig.id - logger.debug("Using custom ARN in completePrompt", { - ctx: "bedrock", - customArn: this.options.awsCustomArn, - }) // Validate ARN format and check region match const 
clientRegion = this.client.config.region as string From 613370a13b32462eddbba9b080706bddac9f583f Mon Sep 17 00:00:00 2001 From: Smartsheet-JB-Brown Date: Thu, 13 Mar 2025 10:52:04 -0700 Subject: [PATCH 4/5] Delete pr-description.md --- pr-description.md | 57 ----------------------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 pr-description.md diff --git a/pr-description.md b/pr-description.md deleted file mode 100644 index 12e085cf5d..0000000000 --- a/pr-description.md +++ /dev/null @@ -1,57 +0,0 @@ -# AWS Bedrock Model Updates and Cost Calculation Improvements - -## Overview - -This pull request updates the AWS Bedrock model definitions with the latest pricing information and improves cost calculation for API providers. The changes ensure accurate cost tracking for both standard API calls and prompt cache operations. - -## Changes - -### 1. Updated AWS Bedrock Model Definitions - -- Updated pricing information for all AWS Bedrock models to match the published list prices for US-West-2 as of March 11, 2025 -- Added support for new models: - - Amazon Nova Pro with latency optimized inference - - Meta Llama 3.3 (70B) Instruct - - Meta Llama 3.2 models (90B, 11B, 3B, 1B) - - Meta Llama 3.1 models (405B, 70B, 8B) -- Added detailed model descriptions for better user understanding -- Added `supportsComputerUse` flag to relevant models - -### 2. Enhanced Cost Calculation - -- Implemented a unified internal cost calculation function that handles: - - Base input token costs - - Output token costs - - Cache creation (writes) costs - - Cache read costs -- Created two specialized cost calculation functions: - - `calculateApiCostAnthropic`: For Anthropic-compliant usage where input tokens count does NOT include cached tokens - - `calculateApiCostOpenAI`: For OpenAI-compliant usage where input tokens count INCLUDES cached tokens - -### 3. 
Improved Custom ARN Handling in Bedrock Provider - -- Enhanced model detection for custom ARNs by implementing a normalized string comparison -- Added better error handling and user feedback for custom ARN issues -- Improved region handling for cross-region inference -- Fixed AWS cost calculation when using a custom ARN, including ARNs for intelligent prompt routing - -### 4. Comprehensive Test Coverage - -- Added extensive unit tests for both cost calculation functions -- Tests cover various scenarios including: - - Basic input/output costs - - Cache writes costs - - Cache reads costs - - Combined cost calculations - - Edge cases (missing prices, zero tokens, undefined values) - -## Benefits - -1. **Accurate Cost Tracking**: Users will see more accurate cost estimates for their API usage, including prompt cache operations -2. **Support for Latest Models**: Access to the newest AWS Bedrock models with correct pricing information -3. **Better Error Handling**: Improved feedback when using custom ARNs or encountering region-specific issues -4. **Consistent Cost Calculation**: Standardized approach to cost calculation across different API providers - -## Testing - -All tests are passing, including the new cost calculation tests and updated Bedrock provider tests. 
From 6c74e9ba7b72b6ed04471c2d6d333615ce9ecea9 Mon Sep 17 00:00:00 2001 From: Smartsheet-JB-Brown Date: Thu, 13 Mar 2025 11:12:17 -0700 Subject: [PATCH 5/5] PR review cleanup --- src/__mocks__/jest.setup.ts | 8 ++------ src/api/providers/__tests__/bedrock.test.ts | 2 +- src/api/providers/bedrock.ts | 18 ++---------------- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/src/__mocks__/jest.setup.ts b/src/__mocks__/jest.setup.ts index 61077be6d8..836279bfe4 100644 --- a/src/__mocks__/jest.setup.ts +++ b/src/__mocks__/jest.setup.ts @@ -1,17 +1,13 @@ // Mock the logger globally for all tests jest.mock("../utils/logging", () => ({ logger: { - debug: jest.fn().mockImplementation((message, meta) => { - console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "") - }), + debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn(), fatal: jest.fn(), child: jest.fn().mockReturnValue({ - debug: jest.fn().mockImplementation((message, meta) => { - console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "") - }), + debug: jest.fn(), info: jest.fn(), warn: jest.fn(), error: jest.fn(), diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index 1afdc6234f..0094c3f12b 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -348,7 +348,7 @@ describe("AwsBedrockHandler", () => { "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.meta.llama3-8b-instruct-v1:0", ) - // + //these should not be the default fall back. 
they should be Llama's config expect(modelInfo.info.maxTokens).toBe(2048) expect(modelInfo.info.contextWindow).toBe(4_000) expect(modelInfo.info.supportsImages).toBe(false) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 399c3a74b4..1637fe29f3 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -295,10 +295,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH if (region === "us." || region === "eu.") modelName = modelName.slice(3) this.costModelConfig = this.getModelByName(modelName) - logger.debug("Updated modelConfig using invokedModelId from a prompt router response", { - ctx: "bedrock", - modelConfig: this.costModelConfig, - }) } // Handle metadata events for the promptRouter. @@ -528,21 +524,17 @@ Please check: modelData.id = this.options.awsCustomArn if (modelData) { - logger.debug(`Matched custom ARN to model: ${modelName}`, { - ctx: "bedrock", - modelData, - }) return modelData } } // An ARN was used, but no model info match found, use default values based on common patterns - let modelInfo = this.getModelByName(bedrockDefaultPromptRouterModelId) + let model = this.getModelByName(bedrockDefaultPromptRouterModelId) // For custom ARNs, always return the specific values expected by tests return { id: this.options.awsCustomArn, - info: modelInfo.info, + info: model.info, } } @@ -556,11 +548,6 @@ Please check: // For tests, allow any model ID (but not custom ARNs, which are handled above) if (process.env.NODE_ENV === "test") { - logger.debug("Return defaults 4", { - ctx: "bedrock", - customArn: this.options.awsCustomArn, - }) - return { id: this.options.apiModelId, info: { @@ -662,7 +649,6 @@ Please check: try { const outputStr = new TextDecoder().decode(response.output) const output = JSON.parse(outputStr) - logger.debug("Bedrock response", { ctx: "bedrock", output: output }) if (output.content) { return output.content }