Cost display updating for Bedrock custom ARNs that are prompt routers #1604
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| # AWS Bedrock Model Updates and Cost Calculation Improvements | ||
|
|
||
| ## Overview | ||
|
|
||
| This pull request updates the AWS Bedrock model definitions with the latest pricing information and improves cost calculation for API providers. The changes ensure accurate cost tracking for both standard API calls and prompt cache operations. | ||
|
|
||
| ## Changes | ||
|
|
||
| ### 1. Updated AWS Bedrock Model Definitions | ||
|
|
||
| - Updated pricing information for all AWS Bedrock models to match the published list prices for the `us-west-2` region as of March 11, 2025 | ||
| - Added support for new models: | ||
| - Amazon Nova Pro with latency optimized inference | ||
| - Meta Llama 3.3 (70B) Instruct | ||
| - Meta Llama 3.2 models (90B, 11B, 3B, 1B) | ||
| - Meta Llama 3.1 models (405B, 70B, 8B) | ||
| - Added detailed model descriptions for better user understanding | ||
| - Added `supportsComputerUse` flag to relevant models | ||
|
|
||
| ### 2. Enhanced Cost Calculation | ||
|
|
||
| - Implemented a unified internal cost calculation function that handles: | ||
| - Base input token costs | ||
| - Output token costs | ||
| - Cache creation (write) costs | ||
| - Cache read costs | ||
| - Created two specialized cost calculation functions: | ||
| - `calculateApiCostAnthropic`: For Anthropic-compliant usage, where the input token count does NOT include cached tokens | ||
| - `calculateApiCostOpenAI`: For OpenAI-compliant usage, where the input token count INCLUDES cached tokens | ||
|
|
||
| ### 3. Improved Custom ARN Handling in Bedrock Provider | ||
|
|
||
| - Enhanced model detection for custom ARNs by implementing a normalized string comparison | ||
| - Added better error handling and user feedback for custom ARN issues | ||
| - Improved region handling for cross-region inference | ||
| - Fixed AWS cost calculation when using a custom ARN, including ARNs for intelligent prompt routing | ||
|
|
||
| ### 4. Comprehensive Test Coverage | ||
|
|
||
| - Added extensive unit tests for both cost calculation functions | ||
| - Tests cover various scenarios including: | ||
| - Basic input/output costs | ||
| - Cache write costs | ||
| - Cache read costs | ||
| - Combined cost calculations | ||
| - Edge cases (missing prices, zero tokens, undefined values) | ||
|
|
||
| ## Benefits | ||
|
|
||
| 1. **Accurate Cost Tracking**: Users will see more accurate cost estimates for their API usage, including prompt cache operations | ||
| 2. **Support for Latest Models**: Access to the newest AWS Bedrock models with correct pricing information | ||
| 3. **Better Error Handling**: Improved feedback when using custom ARNs or encountering region-specific issues | ||
| 4. **Consistent Cost Calculation**: Standardized approach to cost calculation across different API providers | ||
|
|
||
| ## Testing | ||
|
|
||
| All tests are passing, including the new cost calculation tests and updated Bedrock provider tests. | ||
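
The difference between the two cost functions described in section 2 can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: the function names `calculateApiCostAnthropic` and `calculateApiCostOpenAI` come from the description above, but the `ModelPricing` shape, its field names, the parameter order, the `calculateCostInternal` helper, and the per-million-token pricing unit are all assumptions made for this example. The prices in the usage lines are illustrative values, not published Bedrock prices.

```typescript
// Hypothetical pricing shape (assumption): prices are USD per million tokens.
interface ModelPricing {
	inputPrice?: number
	outputPrice?: number
	cacheWritesPrice?: number
	cacheReadsPrice?: number
}

const TOKENS_PER_MILLION = 1_000_000

// Hypothetical shared helper: each count is billed at its own rate.
// Missing prices are treated as zero.
function calculateCostInternal(
	pricing: ModelPricing,
	uncachedInputTokens: number,
	outputTokens: number,
	cacheWriteTokens: number,
	cacheReadTokens: number,
): number {
	const inputCost = ((pricing.inputPrice ?? 0) / TOKENS_PER_MILLION) * uncachedInputTokens
	const outputCost = ((pricing.outputPrice ?? 0) / TOKENS_PER_MILLION) * outputTokens
	const cacheWriteCost = ((pricing.cacheWritesPrice ?? 0) / TOKENS_PER_MILLION) * cacheWriteTokens
	const cacheReadCost = ((pricing.cacheReadsPrice ?? 0) / TOKENS_PER_MILLION) * cacheReadTokens
	return inputCost + outputCost + cacheWriteCost + cacheReadCost
}

// Anthropic-style usage: inputTokens already excludes cached tokens.
function calculateApiCostAnthropic(
	pricing: ModelPricing,
	inputTokens: number,
	outputTokens: number,
	cacheWriteTokens = 0,
	cacheReadTokens = 0,
): number {
	return calculateCostInternal(pricing, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
}

// OpenAI-style usage: inputTokens includes cached tokens, so they are
// subtracted first to avoid billing them twice.
function calculateApiCostOpenAI(
	pricing: ModelPricing,
	inputTokens: number,
	outputTokens: number,
	cacheWriteTokens = 0,
	cacheReadTokens = 0,
): number {
	const uncachedInputTokens = Math.max(0, inputTokens - cacheWriteTokens - cacheReadTokens)
	return calculateCostInternal(pricing, uncachedInputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
}

// Usage with illustrative prices: the same request reported in both conventions.
const examplePricing: ModelPricing = { inputPrice: 3, outputPrice: 15, cacheWritesPrice: 3.75, cacheReadsPrice: 0.3 }
console.log(calculateApiCostAnthropic(examplePricing, 1000, 500, 2000, 3000))
console.log(calculateApiCostOpenAI(examplePricing, 6000, 500, 2000, 3000))
```

In this sketch both calls describe the same request (1,000 uncached input tokens, 2,000 cache-write tokens, 3,000 cache-read tokens), so they produce the same cost; only the meaning of the `inputTokens` argument differs.
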
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,17 @@ | ||
| // Mock the logger globally for all tests | ||
| jest.mock("../utils/logging", () => ({ | ||
| logger: { | ||
| - debug: jest.fn(), | ||
| + debug: jest.fn().mockImplementation((message, meta) => { | ||
| + console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "") | ||
| + }), | ||
| info: jest.fn(), | ||
| warn: jest.fn(), | ||
| error: jest.fn(), | ||
| fatal: jest.fn(), | ||
| child: jest.fn().mockReturnValue({ | ||
| - debug: jest.fn(), | ||
| + debug: jest.fn().mockImplementation((message, meta) => { | ||
| + console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "") | ||
| + }), | ||
|
||
| info: jest.fn(), | ||
| warn: jest.fn(), | ||
| error: jest.fn(), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| // Mock AWS SDK credential providers | ||
| jest.mock("@aws-sdk/credential-providers", () => ({ | ||
| fromIni: jest.fn().mockReturnValue({ | ||
| accessKeyId: "profile-access-key", | ||
| secretAccessKey: "profile-secret-key", | ||
| }), | ||
| })) | ||
|
|
||
| import { AwsBedrockHandler, StreamEvent } from "../bedrock" | ||
| import { ApiHandlerOptions } from "../../../shared/api" | ||
| import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime" | ||
| import { logger } from "../../../utils/logging" | ||
|
|
||
| describe("AwsBedrockHandler createMessage", () => { | ||
| let mockSend: jest.SpyInstance | ||
|
|
||
| beforeEach(() => { | ||
| // Mock the BedrockRuntimeClient.prototype.send method | ||
| mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => { | ||
| return { | ||
| stream: createMockStream([]), | ||
| } | ||
| }) | ||
| }) | ||
|
|
||
| afterEach(() => { | ||
| mockSend.mockRestore() | ||
| }) | ||
|
|
||
| // Helper function to create a mock async iterable stream | ||
| function createMockStream(events: StreamEvent[]) { | ||
| return { | ||
| [Symbol.asyncIterator]: async function* () { | ||
| for (const event of events) { | ||
| yield event | ||
| } | ||
| // Always yield a metadata event at the end | ||
| yield { | ||
| metadata: { | ||
| usage: { | ||
| inputTokens: 100, | ||
| outputTokens: 200, | ||
| }, | ||
| }, | ||
| } | ||
| }, | ||
| } | ||
| } | ||
|
|
||
| it("should log debug information during createMessage with custom ARN", async () => { | ||
| // Create a handler with a custom ARN | ||
| const mockOptions: ApiHandlerOptions = { | ||
| apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||
| awsAccessKey: "test-access-key", | ||
| awsSecretKey: "test-secret-key", | ||
| awsRegion: "us-east-1", | ||
| awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model", | ||
| } | ||
|
|
||
| const handler = new AwsBedrockHandler(mockOptions) | ||
|
|
||
| // Mock the stream to include various events that trigger debug logs | ||
| mockSend.mockImplementationOnce(async () => { | ||
| return { | ||
| stream: createMockStream([ | ||
| // Event with invokedModelId | ||
| { | ||
| trace: { | ||
| promptRouter: { | ||
| invokedModelId: | ||
| "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", | ||
| }, | ||
| }, | ||
| }, | ||
| // Content events | ||
| { | ||
| contentBlockStart: { | ||
| start: { | ||
| text: "Hello", | ||
| }, | ||
| contentBlockIndex: 0, | ||
| }, | ||
| }, | ||
| { | ||
| contentBlockDelta: { | ||
| delta: { | ||
| text: ", world!", | ||
| }, | ||
| contentBlockIndex: 0, | ||
| }, | ||
| }, | ||
| ]), | ||
| } | ||
| }) | ||
|
|
||
| // Create a message generator | ||
| const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) | ||
|
|
||
| // Collect all yielded events | ||
| const events = [] | ||
| for await (const event of messageGenerator) { | ||
| events.push(event) | ||
| } | ||
|
|
||
| // Verify that events were yielded | ||
| expect(events.length).toBeGreaterThan(0) | ||
|
|
||
| // Verify that debug logs were called | ||
| expect(logger.debug).toHaveBeenCalledWith( | ||
| "Using custom ARN for Bedrock request", | ||
| expect.objectContaining({ | ||
| ctx: "bedrock", | ||
| customArn: mockOptions.awsCustomArn, | ||
| }), | ||
| ) | ||
|
|
||
| expect(logger.debug).toHaveBeenCalledWith( | ||
| "Bedrock invokedModelId detected", | ||
| expect.objectContaining({ | ||
| ctx: "bedrock", | ||
| invokedModelId: | ||
| "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0", | ||
| }), | ||
| ) | ||
| }) | ||
|
|
||
| it("should log debug information during createMessage with cross-region inference", async () => { | ||
| // Create a handler with cross-region inference | ||
| const mockOptions: ApiHandlerOptions = { | ||
| apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", | ||
| awsAccessKey: "test-access-key", | ||
| awsSecretKey: "test-secret-key", | ||
| awsRegion: "us-east-1", | ||
| awsUseCrossRegionInference: true, | ||
| } | ||
|
|
||
| const handler = new AwsBedrockHandler(mockOptions) | ||
|
|
||
| // Create a message generator | ||
| const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }]) | ||
|
|
||
| // Collect all yielded events | ||
| const events = [] | ||
| for await (const event of messageGenerator) { | ||
| events.push(event) | ||
| } | ||
|
|
||
| // Verify that events were yielded | ||
| expect(events.length).toBeGreaterThan(0) | ||
| }) | ||
| }) |
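
The first test above feeds the handler an `invokedModelId` ARN from a prompt router trace, and the PR description mentions a "normalized string comparison" for model detection. One way such a lookup could work is sketched below. The PR's actual normalization is not shown on this page, so `modelIdFromArn`, `normalizeModelId`, and `findKnownModel` are hypothetical helpers, and the rule used here (lowercase, alphanumerics only) is an assumption.

```typescript
// Hypothetical helper: returns the resource name after the last "/" of an ARN,
// or undefined if the string does not look like an ARN with a resource name.
function modelIdFromArn(arn: string): string | undefined {
	const slashIndex = arn.lastIndexOf("/")
	if (!arn.startsWith("arn:") || slashIndex === -1 || slashIndex === arn.length - 1) {
		return undefined
	}
	return arn.slice(slashIndex + 1)
}

// Assumed normalization: lowercase and drop every non-alphanumeric character,
// so separators such as ".", "-" and ":" do not affect the comparison.
function normalizeModelId(id: string): string {
	return id.toLowerCase().replace(/[^a-z0-9]/g, "")
}

// Finds the known model ID that matches the ARN's resource name, if any.
function findKnownModel(arn: string, knownModelIds: readonly string[]): string | undefined {
	const candidate = modelIdFromArn(arn)
	if (candidate === undefined) {
		return undefined
	}
	const normalizedCandidate = normalizeModelId(candidate)
	return knownModelIds.find((id) => normalizeModelId(id) === normalizedCandidate)
}

// Usage with the ARN from the test above.
const invokedModelArn = "arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"
console.log(findKnownModel(invokedModelArn, ["anthropic.claude-3-sonnet-20240229-v1:0"]))
```

Once the invoked model is resolved to a known ID, its pricing can be used for cost calculation instead of the pricing of the router or custom ARN itself, which is the behavior the PR title describes.
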