Skip to content

Commit d8df9a5

Browse files
Cost display updating for Bedrock custom ARNs that are prompt routers
1 parent 99258e4 commit d8df9a5

File tree

7 files changed

+902
-114
lines changed

7 files changed

+902
-114
lines changed

pr-description.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# AWS Bedrock Model Updates and Cost Calculation Improvements
2+
3+
## Overview
4+
5+
This pull request updates the AWS Bedrock model definitions with the latest pricing information and improves cost calculation for API providers. The changes ensure accurate cost tracking for both standard API calls and prompt cache operations.
6+
7+
## Changes
8+
9+
### 1. Updated AWS Bedrock Model Definitions
10+
11+
- Updated pricing information for all AWS Bedrock models to match the published list prices for the us-west-2 region as of March 11, 2025
12+
- Added support for new models:
13+
- Amazon Nova Pro with latency-optimized inference
14+
- Meta Llama 3.3 (70B) Instruct
15+
- Meta Llama 3.2 models (90B, 11B, 3B, 1B)
16+
- Meta Llama 3.1 models (405B, 70B, 8B)
17+
- Added detailed model descriptions for better user understanding
18+
- Added `supportsComputerUse` flag to relevant models
19+
20+
### 2. Enhanced Cost Calculation
21+
22+
- Implemented a unified internal cost calculation function that handles:
23+
- Base input token costs
24+
- Output token costs
25+
- Cache creation (write) costs
26+
- Cache read costs
27+
- Created two specialized cost calculation functions:
28+
- `calculateApiCostAnthropic`: For Anthropic-compliant usage, where the input token count does NOT include cached tokens
29+
- `calculateApiCostOpenAI`: For OpenAI-compliant usage, where the input token count INCLUDES cached tokens
30+
31+
### 3. Improved Custom ARN Handling in Bedrock Provider
32+
33+
- Enhanced model detection for custom ARNs by implementing a normalized string comparison
34+
- Added better error handling and user feedback for custom ARN issues
35+
- Improved region handling for cross-region inference
36+
- Fixed AWS cost calculation when using a custom ARN, including ARNs for intelligent prompt routing
37+
38+
### 4. Comprehensive Test Coverage
39+
40+
- Added extensive unit tests for both cost calculation functions
41+
- Tests cover various scenarios including:
42+
- Basic input/output costs
43+
- Cache write costs
44+
- Cache read costs
45+
- Combined cost calculations
46+
- Edge cases (missing prices, zero tokens, undefined values)
47+
48+
## Benefits
49+
50+
1. **Accurate Cost Tracking**: Users will see more accurate cost estimates for their API usage, including prompt cache operations
51+
2. **Support for Latest Models**: Access to the newest AWS Bedrock models with correct pricing information
52+
3. **Better Error Handling**: Improved feedback when using custom ARNs or encountering region-specific issues
53+
4. **Consistent Cost Calculation**: Standardized approach to cost calculation across different API providers
54+
55+
## Testing
56+
57+
All tests are passing, including the new cost calculation tests and updated Bedrock provider tests.

src/__mocks__/jest.setup.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
// Mock the logger globally for all tests
22
jest.mock("../utils/logging", () => ({
33
logger: {
4-
debug: jest.fn(),
4+
debug: jest.fn().mockImplementation((message, meta) => {
5+
console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "")
6+
}),
57
info: jest.fn(),
68
warn: jest.fn(),
79
error: jest.fn(),
810
fatal: jest.fn(),
911
child: jest.fn().mockReturnValue({
10-
debug: jest.fn(),
12+
debug: jest.fn().mockImplementation((message, meta) => {
13+
console.log(`DEBUG: ${message}`, meta ? JSON.stringify(meta) : "")
14+
}),
1115
info: jest.fn(),
1216
warn: jest.fn(),
1317
error: jest.fn(),
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Mock AWS SDK credential providers
2+
jest.mock("@aws-sdk/credential-providers", () => ({
3+
fromIni: jest.fn().mockReturnValue({
4+
accessKeyId: "profile-access-key",
5+
secretAccessKey: "profile-secret-key",
6+
}),
7+
}))
8+
9+
import { AwsBedrockHandler, StreamEvent } from "../bedrock"
10+
import { ApiHandlerOptions } from "../../../shared/api"
11+
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
12+
import { logger } from "../../../utils/logging"
13+
14+
describe("AwsBedrockHandler createMessage", () => {
15+
let mockSend: jest.SpyInstance
16+
17+
beforeEach(() => {
18+
// Mock the BedrockRuntimeClient.prototype.send method
19+
mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => {
20+
return {
21+
stream: createMockStream([]),
22+
}
23+
})
24+
})
25+
26+
afterEach(() => {
27+
mockSend.mockRestore()
28+
})
29+
30+
// Helper function to create a mock async iterable stream
31+
function createMockStream(events: StreamEvent[]) {
32+
return {
33+
[Symbol.asyncIterator]: async function* () {
34+
for (const event of events) {
35+
yield event
36+
}
37+
// Always yield a metadata event at the end
38+
yield {
39+
metadata: {
40+
usage: {
41+
inputTokens: 100,
42+
outputTokens: 200,
43+
},
44+
},
45+
}
46+
},
47+
}
48+
}
49+
50+
it("should log debug information during createMessage with custom ARN", async () => {
51+
// Create a handler with a custom ARN
52+
const mockOptions: ApiHandlerOptions = {
53+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
54+
awsAccessKey: "test-access-key",
55+
awsSecretKey: "test-secret-key",
56+
awsRegion: "us-east-1",
57+
awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model",
58+
}
59+
60+
const handler = new AwsBedrockHandler(mockOptions)
61+
62+
// Mock the stream to include various events that trigger debug logs
63+
mockSend.mockImplementationOnce(async () => {
64+
return {
65+
stream: createMockStream([
66+
// Event with invokedModelId
67+
{
68+
trace: {
69+
promptRouter: {
70+
invokedModelId:
71+
"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
72+
},
73+
},
74+
},
75+
// Content events
76+
{
77+
contentBlockStart: {
78+
start: {
79+
text: "Hello",
80+
},
81+
contentBlockIndex: 0,
82+
},
83+
},
84+
{
85+
contentBlockDelta: {
86+
delta: {
87+
text: ", world!",
88+
},
89+
contentBlockIndex: 0,
90+
},
91+
},
92+
]),
93+
}
94+
})
95+
96+
// Create a message generator
97+
const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
98+
99+
// Collect all yielded events
100+
const events = []
101+
for await (const event of messageGenerator) {
102+
events.push(event)
103+
}
104+
105+
// Verify that events were yielded
106+
expect(events.length).toBeGreaterThan(0)
107+
108+
// Verify that debug logs were called
109+
expect(logger.debug).toHaveBeenCalledWith(
110+
"Using custom ARN for Bedrock request",
111+
expect.objectContaining({
112+
ctx: "bedrock",
113+
customArn: mockOptions.awsCustomArn,
114+
}),
115+
)
116+
117+
expect(logger.debug).toHaveBeenCalledWith(
118+
"Bedrock invokedModelId detected",
119+
expect.objectContaining({
120+
ctx: "bedrock",
121+
invokedModelId:
122+
"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
123+
}),
124+
)
125+
})
126+
127+
it("should log debug information during createMessage with cross-region inference", async () => {
128+
// Create a handler with cross-region inference
129+
const mockOptions: ApiHandlerOptions = {
130+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
131+
awsAccessKey: "test-access-key",
132+
awsSecretKey: "test-secret-key",
133+
awsRegion: "us-east-1",
134+
awsUseCrossRegionInference: true,
135+
}
136+
137+
const handler = new AwsBedrockHandler(mockOptions)
138+
139+
// Create a message generator
140+
const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
141+
142+
// Collect all yielded events
143+
const events = []
144+
for await (const event of messageGenerator) {
145+
events.push(event)
146+
}
147+
148+
// Verify that events were yielded
149+
expect(events.length).toBeGreaterThan(0)
150+
})
151+
})

0 commit comments

Comments
 (0)