57 changes: 57 additions & 0 deletions pr-description.md
@@ -0,0 +1,57 @@
# AWS Bedrock Model Updates and Cost Calculation Improvements

## Overview

This pull request updates the AWS Bedrock model definitions with the latest pricing information and improves cost calculation for API providers. The changes ensure accurate cost tracking for both standard API calls and prompt cache operations.

## Changes

### 1. Updated AWS Bedrock Model Definitions

- Updated pricing information for all AWS Bedrock models to match the published list prices for US-West-2 as of March 11, 2025
- Added support for new models:
  - Amazon Nova Pro with latency-optimized inference
  - Meta Llama 3.3 (70B) Instruct
  - Meta Llama 3.2 models (90B, 11B, 3B, 1B)
  - Meta Llama 3.1 models (405B, 70B, 8B)
- Added detailed model descriptions for better user understanding
- Added `supportsComputerUse` flag to relevant models (an illustrative model entry is sketched below)
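
For illustration, a single updated model entry might look roughly like the following sketch. The interface and field names here are assumptions chosen for readability, and the prices are placeholder values, not the published list prices referenced above.

```typescript
// Hypothetical shape of a Bedrock model definition; field names and units are
// assumptions, not the repository's actual types. Prices are USD per 1M tokens.
interface BedrockModelInfo {
	maxTokens: number
	contextWindow: number
	supportsPromptCache: boolean
	supportsComputerUse?: boolean
	inputPrice: number // USD per 1M input tokens
	outputPrice: number // USD per 1M output tokens
	cacheWritesPrice?: number // USD per 1M cache-write tokens
	cacheReadsPrice?: number // USD per 1M cache-read tokens
	description?: string
}

// Example entry with placeholder numbers (illustrative only)
const exampleBedrockModel: BedrockModelInfo = {
	maxTokens: 8192,
	contextWindow: 200_000,
	supportsPromptCache: true,
	supportsComputerUse: true,
	inputPrice: 3.0,
	outputPrice: 15.0,
	cacheWritesPrice: 3.75,
	cacheReadsPrice: 0.3,
	description: "Example Bedrock model entry showing pricing and capability flags",
}
```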

### 2. Enhanced Cost Calculation

- Implemented a unified internal cost calculation function that handles:
  - Base input token costs
  - Output token costs
  - Cache creation (writes) costs
  - Cache read costs
- Created two specialized cost calculation functions (a rough sketch follows this list):
  - `calculateApiCostAnthropic`: For Anthropic-compliant usage, where the input token count does NOT include cached tokens
  - `calculateApiCostOpenAI`: For OpenAI-compliant usage, where the input token count INCLUDES cached tokens
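
A rough sketch of how the shared helper and the two wrappers could relate is shown below. The helper name, parameter order, and the `UsagePrices` type are assumptions for illustration; only the two wrapper names come from this PR.

```typescript
// Illustrative sketch only; the real signatures in the codebase may differ.
interface UsagePrices {
	inputPrice?: number // USD per 1M input tokens
	outputPrice?: number // USD per 1M output tokens
	cacheWritesPrice?: number // USD per 1M cache-write tokens
	cacheReadsPrice?: number // USD per 1M cache-read tokens
}

// Unified internal calculation: expects inputTokens to already exclude cached tokens.
function calculateCostInternal(
	prices: UsagePrices,
	inputTokens: number,
	outputTokens: number,
	cacheCreationTokens: number,
	cacheReadTokens: number,
): number {
	const perToken = (price?: number) => (price ?? 0) / 1_000_000
	return (
		perToken(prices.inputPrice) * inputTokens +
		perToken(prices.outputPrice) * outputTokens +
		perToken(prices.cacheWritesPrice) * cacheCreationTokens +
		perToken(prices.cacheReadsPrice) * cacheReadTokens
	)
}

// Anthropic-compliant usage: inputTokens does NOT include cached tokens.
export function calculateApiCostAnthropic(
	prices: UsagePrices,
	inputTokens: number,
	outputTokens: number,
	cacheCreationTokens = 0,
	cacheReadTokens = 0,
): number {
	return calculateCostInternal(prices, inputTokens, outputTokens, cacheCreationTokens, cacheReadTokens)
}

// OpenAI-compliant usage: inputTokens INCLUDES cached tokens, so the cached
// portion is subtracted before pricing the uncached input.
export function calculateApiCostOpenAI(
	prices: UsagePrices,
	inputTokens: number,
	outputTokens: number,
	cacheCreationTokens = 0,
	cacheReadTokens = 0,
): number {
	const uncachedInput = Math.max(0, inputTokens - cacheCreationTokens - cacheReadTokens)
	return calculateCostInternal(prices, uncachedInput, outputTokens, cacheCreationTokens, cacheReadTokens)
}
```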

### 3. Improved Custom ARN Handling in Bedrock Provider

- Enhanced model detection for custom ARNs by implementing a normalized string comparison (see the sketch after this list)
- Added better error handling and user feedback for custom ARN issues
- Improved region handling for cross-region inference
- Fixed AWS cost calculation when using a custom ARN, including ARNs for intelligent prompt routing
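
The normalized comparison could work roughly like the sketch below; the helper names and the exact normalization rules are assumptions, not the provider's actual implementation.

```typescript
// Hypothetical helpers: compare the model id embedded in a custom ARN against a
// known model id, ignoring case, surrounding whitespace, and cross-region
// inference prefixes such as "us." or "eu.".
function normalizeModelId(id: string): string {
	return id.trim().toLowerCase().replace(/^(us|eu|apac)\./, "")
}

function arnMatchesModel(customArn: string, knownModelId: string): boolean {
	// The model id is the resource portion after the final "/", e.g.
	// "arn:aws:bedrock:us-west-2:123456789012:foundation-model/<model-id>"
	const resource = customArn.split("/").pop() ?? ""
	return normalizeModelId(resource) === normalizeModelId(knownModelId)
}

// Usage example (illustrative ARN and account id)
arnMatchesModel(
	"arn:aws:bedrock:us-west-2:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
	"anthropic.claude-3-5-sonnet-20241022-v2:0",
) // => true
```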

### 4. Comprehensive Test Coverage

- Added extensive unit tests for both cost calculation functions
- Tests cover various scenarios, including:
  - Basic input/output costs
  - Cache writes costs
  - Cache reads costs
  - Combined cost calculations
  - Edge cases (missing prices, zero tokens, undefined values); an example of such a test is sketched below
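
As an example of the kind of edge case covered, a hypothetical test might look like the sketch below; the import path and function signature follow the illustrative sketch in section 2, not necessarily the real module layout.

```typescript
// Hypothetical test; adjust the import path and signature to the real module.
import { calculateApiCostOpenAI } from "../cost"

describe("calculateApiCostOpenAI edge cases", () => {
	it("returns 0 when prices are missing or token counts are zero", () => {
		// No prices defined: every term falls back to 0
		expect(calculateApiCostOpenAI({}, 1000, 500)).toBe(0)

		// Zero tokens: cost is 0 regardless of prices
		expect(calculateApiCostOpenAI({ inputPrice: 3, outputPrice: 15 }, 0, 0)).toBe(0)
	})
})
```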

## Benefits

1. **Accurate Cost Tracking**: Users will see more accurate cost estimates for their API usage, including prompt cache operations
2. **Support for Latest Models**: Access to the newest AWS Bedrock models with correct pricing information
3. **Better Error Handling**: Improved feedback when using custom ARNs or encountering region-specific issues
4. **Consistent Cost Calculation**: Standardized approach to cost calculation across different API providers

## Testing

All tests are passing, including the new cost calculation tests and updated Bedrock provider tests.
8 changes: 6 additions & 2 deletions src/__mocks__/jest.setup.ts
@@ -1,13 +1,17 @@
 // Mock the logger globally for all tests
 jest.mock("../utils/logging", () => ({
 	logger: {
-		debug: jest.fn(),
+		debug: jest.fn().mockImplementation((message, meta) => {
+			console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "")
+		}),

**Collaborator:** Can we revert the changes to this file and mock debug within specific tests if necessary?

**Contributor Author:** yes, done

 		info: jest.fn(),
 		warn: jest.fn(),
 		error: jest.fn(),
 		fatal: jest.fn(),
 		child: jest.fn().mockReturnValue({
-			debug: jest.fn(),
+			debug: jest.fn().mockImplementation((message, meta) => {
+				console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "")
+			}),

**Collaborator:** Think we can revert this one as well?

 			info: jest.fn(),
 			warn: jest.fn(),
 			error: jest.fn(),
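
To illustrate the reviewer's suggestion of mocking `debug` within specific tests instead, one per-test override might look like the sketch below (a sketch, not code from this PR; it assumes the global `logger` mock from `jest.setup.ts` and the import path used by the Bedrock tests).

```typescript
import { logger } from "../../../utils/logging"

it("surfaces debug output for this scenario only", async () => {
	// The global mock leaves logger.debug as a plain jest.fn(); forward it to
	// the console just for this test.
	const debugMock = logger.debug as jest.Mock
	debugMock.mockImplementation((message: string, meta?: unknown) =>
		console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : ""),
	)

	// ... exercise the code under test here ...

	// Drop the temporary implementation so other tests see a silent mock again.
	debugMock.mockReset()
})
```
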
151 changes: 151 additions & 0 deletions src/api/providers/__tests__/bedrock-createMessage.test.ts
@@ -0,0 +1,151 @@
// Mock AWS SDK credential providers
jest.mock("@aws-sdk/credential-providers", () => ({
	fromIni: jest.fn().mockReturnValue({
		accessKeyId: "profile-access-key",
		secretAccessKey: "profile-secret-key",
	}),
}))

import { AwsBedrockHandler, StreamEvent } from "../bedrock"
import { ApiHandlerOptions } from "../../../shared/api"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { logger } from "../../../utils/logging"

describe("AwsBedrockHandler createMessage", () => {
	let mockSend: jest.SpyInstance

	beforeEach(() => {
		// Mock the BedrockRuntimeClient.prototype.send method
		mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => {
			return {
				stream: createMockStream([]),
			}
		})
	})

	afterEach(() => {
		mockSend.mockRestore()
	})

	// Helper function to create a mock async iterable stream
	function createMockStream(events: StreamEvent[]) {
		return {
			[Symbol.asyncIterator]: async function* () {
				for (const event of events) {
					yield event
				}
				// Always yield a metadata event at the end
				yield {
					metadata: {
						usage: {
							inputTokens: 100,
							outputTokens: 200,
						},
					},
				}
			},
		}
	}

	it("should log debug information during createMessage with custom ARN", async () => {
		// Create a handler with a custom ARN
		const mockOptions: ApiHandlerOptions = {
			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
			awsAccessKey: "test-access-key",
			awsSecretKey: "test-secret-key",
			awsRegion: "us-east-1",
			awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model",
		}

		const handler = new AwsBedrockHandler(mockOptions)

		// Mock the stream to include various events that trigger debug logs
		mockSend.mockImplementationOnce(async () => {
			return {
				stream: createMockStream([
					// Event with invokedModelId
					{
						trace: {
							promptRouter: {
								invokedModelId:
									"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
							},
						},
					},
					// Content events
					{
						contentBlockStart: {
							start: {
								text: "Hello",
							},
							contentBlockIndex: 0,
						},
					},
					{
						contentBlockDelta: {
							delta: {
								text: ", world!",
							},
							contentBlockIndex: 0,
						},
					},
				]),
			}
		})

		// Create a message generator
		const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])

		// Collect all yielded events
		const events = []
		for await (const event of messageGenerator) {
			events.push(event)
		}

		// Verify that events were yielded
		expect(events.length).toBeGreaterThan(0)

		// Verify that debug logs were called
		expect(logger.debug).toHaveBeenCalledWith(
			"Using custom ARN for Bedrock request",
			expect.objectContaining({
				ctx: "bedrock",
				customArn: mockOptions.awsCustomArn,
			}),
		)

		expect(logger.debug).toHaveBeenCalledWith(
			"Bedrock invokedModelId detected",
			expect.objectContaining({
				ctx: "bedrock",
				invokedModelId:
					"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
			}),
		)
	})

	it("should log debug information during createMessage with cross-region inference", async () => {
		// Create a handler with cross-region inference
		const mockOptions: ApiHandlerOptions = {
			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
			awsAccessKey: "test-access-key",
			awsSecretKey: "test-secret-key",
			awsRegion: "us-east-1",
			awsUseCrossRegionInference: true,
		}

		const handler = new AwsBedrockHandler(mockOptions)

		// Create a message generator
		const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])

		// Collect all yielded events
		const events = []
		for await (const event of messageGenerator) {
			events.push(event)
		}

		// Verify that events were yielded
		expect(events.length).toBeGreaterThan(0)
	})
})