Skip to content

Commit fbdf758

Browse files
authored
Merge pull request #1604 from Smartsheet-JB-Brown/jbbrown/bedrock_cost_intelligent_prompt_routing
Cost display updating for Bedrock custom ARNs that are prompt routers
2 parents cd9fb24 + 4ada518 commit fbdf758

File tree

4 files changed

+613
-117
lines changed

4 files changed

+613
-117
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
// Mock AWS SDK credential providers
2+
jest.mock("@aws-sdk/credential-providers", () => ({
3+
fromIni: jest.fn().mockReturnValue({
4+
accessKeyId: "profile-access-key",
5+
secretAccessKey: "profile-secret-key",
6+
}),
7+
}))
8+
9+
import { AwsBedrockHandler, StreamEvent } from "../bedrock"
10+
import { ApiHandlerOptions } from "../../../shared/api"
11+
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
12+
13+
describe("AwsBedrockHandler with invokedModelId", () => {
14+
let mockSend: jest.SpyInstance
15+
16+
beforeEach(() => {
17+
// Mock the BedrockRuntimeClient.prototype.send method
18+
mockSend = jest.spyOn(BedrockRuntimeClient.prototype, "send").mockImplementation(async () => {
19+
return {
20+
stream: createMockStream([]),
21+
}
22+
})
23+
})
24+
25+
afterEach(() => {
26+
mockSend.mockRestore()
27+
})
28+
29+
// Helper function to create a mock async iterable stream
30+
function createMockStream(events: StreamEvent[]) {
31+
return {
32+
[Symbol.asyncIterator]: async function* () {
33+
for (const event of events) {
34+
yield event
35+
}
36+
// Always yield a metadata event at the end
37+
yield {
38+
metadata: {
39+
usage: {
40+
inputTokens: 100,
41+
outputTokens: 200,
42+
},
43+
},
44+
}
45+
},
46+
}
47+
}
48+
49+
it("should update costModelConfig when invokedModelId is present in the stream", async () => {
50+
// Create a handler with a custom ARN
51+
const mockOptions: ApiHandlerOptions = {
52+
// apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
53+
awsAccessKey: "test-access-key",
54+
awsSecretKey: "test-secret-key",
55+
awsRegion: "us-east-1",
56+
awsCustomArn: "arn:aws:bedrock:us-west-2:699475926481:default-prompt-router/anthropic.claude:1",
57+
}
58+
59+
const handler = new AwsBedrockHandler(mockOptions)
60+
61+
// Create a spy on the getModel method before mocking it
62+
const getModelSpy = jest.spyOn(handler, "getModelByName")
63+
64+
// Mock the stream to include an event with invokedModelId and usage metadata
65+
mockSend.mockImplementationOnce(async () => {
66+
return {
67+
stream: createMockStream([
68+
// First event with invokedModelId and usage metadata
69+
{
70+
trace: {
71+
promptRouter: {
72+
invokedModelId:
73+
"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
74+
usage: {
75+
inputTokens: 150,
76+
outputTokens: 250,
77+
},
78+
},
79+
},
80+
// Some content events
81+
},
82+
{
83+
contentBlockStart: {
84+
start: {
85+
text: "Hello",
86+
},
87+
contentBlockIndex: 0,
88+
},
89+
},
90+
{
91+
contentBlockDelta: {
92+
delta: {
93+
text: ", world!",
94+
},
95+
contentBlockIndex: 0,
96+
},
97+
},
98+
]),
99+
}
100+
})
101+
102+
// Create a message generator
103+
const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
104+
105+
// Collect all yielded events to verify usage events
106+
const events = []
107+
for await (const event of messageGenerator) {
108+
events.push(event)
109+
}
110+
111+
// Verify that getModel was called with the correct model name
112+
expect(getModelSpy).toHaveBeenCalledWith("anthropic.claude-3-5-sonnet-20240620-v1:0")
113+
114+
// Verify that getModel returns the updated model info
115+
const costModel = handler.getModel()
116+
expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20240620-v1:0")
117+
expect(costModel.info.inputPrice).toBe(3)
118+
119+
// Verify that a usage event was emitted after updating the costModelConfig
120+
const usageEvents = events.filter((event) => event.type === "usage")
121+
expect(usageEvents.length).toBeGreaterThanOrEqual(1)
122+
123+
// The last usage event should have the token counts from the metadata
124+
const lastUsageEvent = usageEvents[usageEvents.length - 1]
125+
expect(lastUsageEvent).toEqual({
126+
type: "usage",
127+
inputTokens: 100,
128+
outputTokens: 200,
129+
})
130+
})
131+
132+
it("should not update costModelConfig when invokedModelId is not present", async () => {
133+
// Create a handler with default settings
134+
const mockOptions: ApiHandlerOptions = {
135+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
136+
awsAccessKey: "test-access-key",
137+
awsSecretKey: "test-secret-key",
138+
awsRegion: "us-east-1",
139+
}
140+
141+
const handler = new AwsBedrockHandler(mockOptions)
142+
143+
// Mock the stream without an invokedModelId event
144+
mockSend.mockImplementationOnce(async () => {
145+
return {
146+
stream: createMockStream([
147+
// Some content events but no invokedModelId
148+
{
149+
contentBlockStart: {
150+
start: {
151+
text: "Hello",
152+
},
153+
contentBlockIndex: 0,
154+
},
155+
},
156+
{
157+
contentBlockDelta: {
158+
delta: {
159+
text: ", world!",
160+
},
161+
contentBlockIndex: 0,
162+
},
163+
},
164+
]),
165+
}
166+
})
167+
168+
// Mock getModel to return expected values
169+
const getModelSpy = jest.spyOn(handler, "getModel").mockReturnValue({
170+
id: "anthropic.claude-3-5-sonnet-20241022-v2:0",
171+
info: {
172+
maxTokens: 4096,
173+
contextWindow: 128_000,
174+
supportsPromptCache: false,
175+
supportsImages: true,
176+
},
177+
})
178+
179+
// Create a message generator
180+
const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
181+
182+
// Consume the generator
183+
for await (const _ of messageGenerator) {
184+
// Just consume the messages
185+
}
186+
187+
// Verify that getModel returns the original model info
188+
const costModel = handler.getModel()
189+
expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
190+
191+
// Verify getModel was not called with a model name parameter
192+
expect(getModelSpy).not.toHaveBeenCalledWith(expect.any(String))
193+
})
194+
195+
it("should handle invalid invokedModelId format gracefully", async () => {
196+
// Create a handler with default settings
197+
const mockOptions: ApiHandlerOptions = {
198+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
199+
awsAccessKey: "test-access-key",
200+
awsSecretKey: "test-secret-key",
201+
awsRegion: "us-east-1",
202+
}
203+
204+
const handler = new AwsBedrockHandler(mockOptions)
205+
206+
// Mock the stream with an invalid invokedModelId
207+
mockSend.mockImplementationOnce(async () => {
208+
return {
209+
stream: createMockStream([
210+
// Event with invalid invokedModelId format
211+
{
212+
trace: {
213+
promptRouter: {
214+
invokedModelId: "invalid-format-not-an-arn",
215+
},
216+
},
217+
},
218+
// Some content events
219+
{
220+
contentBlockStart: {
221+
start: {
222+
text: "Hello",
223+
},
224+
contentBlockIndex: 0,
225+
},
226+
},
227+
]),
228+
}
229+
})
230+
231+
// Mock getModel to return expected values
232+
const getModelSpy = jest.spyOn(handler, "getModel").mockReturnValue({
233+
id: "anthropic.claude-3-5-sonnet-20241022-v2:0",
234+
info: {
235+
maxTokens: 4096,
236+
contextWindow: 128_000,
237+
supportsPromptCache: false,
238+
supportsImages: true,
239+
},
240+
})
241+
242+
// Create a message generator
243+
const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
244+
245+
// Consume the generator
246+
for await (const _ of messageGenerator) {
247+
// Just consume the messages
248+
}
249+
250+
// Verify that getModel returns the original model info
251+
const costModel = handler.getModel()
252+
expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
253+
})
254+
255+
it("should handle errors during invokedModelId processing", async () => {
256+
// Create a handler with default settings
257+
const mockOptions: ApiHandlerOptions = {
258+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
259+
awsAccessKey: "test-access-key",
260+
awsSecretKey: "test-secret-key",
261+
awsRegion: "us-east-1",
262+
}
263+
264+
const handler = new AwsBedrockHandler(mockOptions)
265+
266+
// Mock the stream with a valid invokedModelId
267+
mockSend.mockImplementationOnce(async () => {
268+
return {
269+
stream: createMockStream([
270+
// Event with valid invokedModelId
271+
{
272+
trace: {
273+
promptRouter: {
274+
invokedModelId:
275+
"arn:aws:bedrock:us-east-1:123456789:foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
276+
},
277+
},
278+
},
279+
]),
280+
}
281+
})
282+
283+
// Mock getModel to throw an error when called with the model name
284+
jest.spyOn(handler, "getModel").mockImplementation((modelName?: string) => {
285+
if (modelName === "anthropic.claude-3-sonnet-20240229-v1:0") {
286+
throw new Error("Test error during model lookup")
287+
}
288+
289+
// Default return value for initial call
290+
return {
291+
id: "anthropic.claude-3-5-sonnet-20241022-v2:0",
292+
info: {
293+
maxTokens: 4096,
294+
contextWindow: 128_000,
295+
supportsPromptCache: false,
296+
supportsImages: true,
297+
},
298+
}
299+
})
300+
301+
// Create a message generator
302+
const messageGenerator = handler.createMessage("system prompt", [{ role: "user", content: "user message" }])
303+
304+
// Consume the generator
305+
for await (const _ of messageGenerator) {
306+
// Just consume the messages
307+
}
308+
309+
// Verify that getModel returns the original model info
310+
const costModel = handler.getModel()
311+
expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
312+
})
313+
})

0 commit comments

Comments
 (0)