Skip to content

Commit bbd5b6b

Browse files
fix prompt router bug
1 parent 7feec50 commit bbd5b6b

File tree

3 files changed

+50
-204
lines changed

3 files changed

+50
-204
lines changed

src/api/providers/__tests__/bedrock-invokedModelId.test.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,17 @@ describe("AwsBedrockHandler with invokedModelId", () => {
103103
awsAccessKey: "test-access-key",
104104
awsSecretKey: "test-secret-key",
105105
awsRegion: "us-east-1",
106-
awsCustomArn: "arn:aws:bedrock:us-west-2:699475926481:default-prompt-router/anthropic.claude:1",
106+
awsCustomArn: "arn:aws:bedrock:us-west-2:123456789:default-prompt-router/anthropic.claude:1",
107107
}
108108

109109
const handler = new AwsBedrockHandler(mockOptions)
110110

111-
// Create a spy on the getModel method before mocking it
111+
// Verify the model info that getModel returns before the stream is handled
112+
const initialModel = handler.getModel()
113+
// The default prompt router model has an input price of 3. After the stream is handled, it should be updated to 8.
114+
expect(initialModel.info.inputPrice).toBe(3)
115+
116+
// Create a spy on getModelById
112117
const getModelByIdSpy = jest.spyOn(handler, "getModelById")
113118

114119
// Mock the stream to include an event with invokedModelId and usage metadata
@@ -120,7 +125,7 @@ describe("AwsBedrockHandler with invokedModelId", () => {
120125
trace: {
121126
promptRouter: {
122127
invokedModelId:
123-
"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
128+
"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-2-1-v1:0",
124129
usage: {
125130
inputTokens: 150,
126131
outputTokens: 250,
@@ -159,13 +164,13 @@ describe("AwsBedrockHandler with invokedModelId", () => {
159164
events.push(event)
160165
}
161166

162-
// Verify that getModelById was called with the full ARN
163-
expect(getModelByIdSpy).toHaveBeenCalledWith("anthropic.claude-3-5-sonnet-20240620-v1:0")
167+
// Verify that getModelById was called with the model ID and model type, not the full ARN
168+
expect(getModelByIdSpy).toHaveBeenCalledWith("anthropic.claude-2-1-v1:0", "inference-profile")
164169

165170
// Verify that getModel returns the updated model info
166171
const costModel = handler.getModel()
167172
//expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20240620-v1:0")
168-
expect(costModel.info.inputPrice).toBe(3)
173+
expect(costModel.info.inputPrice).toBe(8)
169174

170175
// Verify that a usage event was emitted after updating the costModelConfig
171176
const usageEvents = events.filter((event) => event.type === "usage")

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 11 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,11 @@ describe("AwsBedrockHandler", () => {
7777

7878
const modelInfo = customArnHandler.getModel()
7979

80-
// Verify the ARN is preserved as the ID
8180
expect(modelInfo.id).toBe(
8281
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
83-
)
84-
85-
// Verify the model info is defined
86-
expect(modelInfo.info).toBeDefined()
82+
),
83+
// Verify the model info is defined
84+
expect(modelInfo.info).toBeDefined()
8785

8886
// Verify parseArn was called with the correct ARN
8987
expect(parseArnMock).toHaveBeenCalledWith(
@@ -102,177 +100,21 @@ describe("AwsBedrockHandler", () => {
102100
}
103101
})
104102

105-
it("should use default model when custom-arn is selected but no ARN is provided", () => {
103+
it("should use default prompt router model when prompt router arn is entered but no model can be identified from the ARN", () => {
106104
const customArnHandler = new AwsBedrockHandler({
107-
apiModelId: "custom-arn",
105+
awsCustomArn:
106+
"arn:aws:bedrock:ap-northeast-3:123456789012:default-prompt-router/my_router_arn_no_model",
108107
awsAccessKey: "test-access-key",
109108
awsSecretKey: "test-secret-key",
110109
awsRegion: "us-east-1",
111-
// No awsCustomArn provided
112110
})
113111
const modelInfo = customArnHandler.getModel()
114-
// Should fall back to default model
115-
expect(modelInfo.id).not.toBe("custom-arn")
112+
// Should keep the ARN as the model ID while falling back to the default prompt router model info
113+
expect(modelInfo.id).toBe(
114+
"arn:aws:bedrock:ap-northeast-3:123456789012:default-prompt-router/my_router_arn_no_model",
115+
) // bedrockDefaultPromptRouterModelId
116116
expect(modelInfo.info).toBeDefined()
117-
})
118-
})
119-
120-
describe("invokedModelId handling", () => {
121-
it("should update costModelConfig when invokedModelId is present in custom ARN scenario", async () => {
122-
const customArnHandler = new AwsBedrockHandler({
123-
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
124-
awsAccessKey: "test-access-key",
125-
awsSecretKey: "test-secret-key",
126-
awsRegion: "us-east-1",
127-
awsCustomArn: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model",
128-
})
129-
130-
const mockStreamEvent = {
131-
trace: {
132-
promptRouter: {
133-
invokedModelId: "arn:aws:bedrock:us-east-1:123456789:foundation-model/custom-model:0",
134-
},
135-
},
136-
}
137-
138-
jest.spyOn(customArnHandler, "getModel").mockReturnValue({
139-
id: "custom-model",
140-
info: {
141-
maxTokens: 4096,
142-
contextWindow: 128_000,
143-
supportsPromptCache: false,
144-
supportsImages: true,
145-
},
146-
})
147-
148-
await customArnHandler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next()
149-
150-
expect(customArnHandler.getModel()).toEqual({
151-
id: "custom-model",
152-
info: {
153-
maxTokens: 4096,
154-
contextWindow: 128_000,
155-
supportsPromptCache: false,
156-
supportsImages: true,
157-
},
158-
})
159-
})
160-
161-
it("should update costModelConfig when invokedModelId is present in default model scenario", async () => {
162-
handler = new AwsBedrockHandler({
163-
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
164-
awsAccessKey: "test-access-key",
165-
awsSecretKey: "test-secret-key",
166-
awsRegion: "us-east-1",
167-
})
168-
169-
const mockStreamEvent = {
170-
trace: {
171-
promptRouter: {
172-
invokedModelId: "arn:aws:bedrock:us-east-1:123456789:foundation-model/default-model:0",
173-
},
174-
},
175-
}
176-
177-
jest.spyOn(handler, "getModel").mockReturnValue({
178-
id: "default-model",
179-
info: {
180-
maxTokens: 4096,
181-
contextWindow: 128_000,
182-
supportsPromptCache: false,
183-
supportsImages: true,
184-
},
185-
})
186-
187-
await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next()
188-
189-
expect(handler.getModel()).toEqual({
190-
id: "default-model",
191-
info: {
192-
maxTokens: 4096,
193-
contextWindow: 128_000,
194-
supportsPromptCache: false,
195-
supportsImages: true,
196-
},
197-
})
198-
})
199-
200-
it("should not update costModelConfig when invokedModelId is not present", async () => {
201-
handler = new AwsBedrockHandler({
202-
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
203-
awsAccessKey: "test-access-key",
204-
awsSecretKey: "test-secret-key",
205-
awsRegion: "us-east-1",
206-
})
207-
208-
const mockStreamEvent = {
209-
trace: {
210-
promptRouter: {
211-
// No invokedModelId present
212-
},
213-
},
214-
}
215-
216-
jest.spyOn(handler, "getModel").mockReturnValue({
217-
id: "default-model",
218-
info: {
219-
maxTokens: 4096,
220-
contextWindow: 128_000,
221-
supportsPromptCache: false,
222-
supportsImages: true,
223-
},
224-
})
225-
226-
await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next()
227-
228-
expect(handler.getModel()).toEqual({
229-
id: "default-model",
230-
info: {
231-
maxTokens: 4096,
232-
contextWindow: 128_000,
233-
supportsPromptCache: false,
234-
supportsImages: true,
235-
},
236-
})
237-
})
238-
239-
it("should not update costModelConfig when invokedModelId cannot be parsed", async () => {
240-
handler = new AwsBedrockHandler({
241-
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
242-
awsAccessKey: "test-access-key",
243-
awsSecretKey: "test-secret-key",
244-
awsRegion: "us-east-1",
245-
})
246-
247-
const mockStreamEvent = {
248-
trace: {
249-
promptRouter: {
250-
invokedModelId: "invalid-arn",
251-
},
252-
},
253-
}
254-
255-
jest.spyOn(handler, "getModel").mockReturnValue({
256-
id: "default-model",
257-
info: {
258-
maxTokens: 4096,
259-
contextWindow: 128_000,
260-
supportsPromptCache: false,
261-
supportsImages: true,
262-
},
263-
})
264-
265-
await handler.createMessage("system prompt", [{ role: "user", content: "user message" }]).next()
266-
267-
expect(handler.getModel()).toEqual({
268-
id: "default-model",
269-
info: {
270-
maxTokens: 4096,
271-
contextWindow: 128_000,
272-
supportsPromptCache: false,
273-
supportsImages: true,
274-
},
275-
})
117+
expect(modelInfo.info.maxTokens).toBe(4096)
276118
})
277119
})
278120
})

src/api/providers/bedrock.ts

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,15 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
296296

297297
if (streamEvent?.trace?.promptRouter?.invokedModelId) {
298298
try {
299-
let invokedModelArn = this.parseArn(streamEvent.trace.promptRouter.invokedModelId)
300-
if (invokedModelArn?.modelId) {
301-
//update the in-use model info to be based on the invoked Model Id for the router
302-
//so that pricing, context window, caching etc have values that can be used
303-
//However, we want to keep the id of the model to be the ID for the router for
304-
//subsequent requests so they are sent back through the router
305-
let invokedModel = this.getModelById(invokedModelArn.modelId as string)
306-
if (invokedModel) {
307-
invokedModel.id = modelConfig.id
308-
this.costModelConfig = invokedModel
309-
}
299+
//update the in-use model info to be based on the invoked Model Id for the router
300+
//so that pricing, context window, caching etc have values that can be used
301+
//However, we want to keep the id of the model to be the ID for the router for
302+
//subsequent requests so they are sent back through the router
303+
let invokedArnInfo = this.parseArn(streamEvent.trace.promptRouter.invokedModelId)
304+
let invokedModel = this.getModelById(invokedArnInfo.modelId as string, invokedArnInfo.modelType)
305+
if (invokedModel) {
306+
invokedModel.id = modelConfig.id
307+
this.costModelConfig = invokedModel
310308
}
311309

312310
// Handle metadata events for the promptRouter.
@@ -626,26 +624,28 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
626624
}
627625

628626
// Prompt Router responses come back in a different sequence; the model used is reported in the response and must be fetched by name.
629-
getModelById(modelId: string): { id: BedrockModelId | string; info: SharedModelInfo } {
627+
getModelById(modelId: string, modelType?: string): { id: BedrockModelId | string; info: SharedModelInfo } {
630628
// Try to find the model in bedrockModels
631629
let baseModelId = this.parseBaseModelId(modelId)
630+
const id = baseModelId as BedrockModelId
631+
let model
632632
if (baseModelId in bedrockModels) {
633-
const id = baseModelId as BedrockModelId
634-
635633
//Do a deep copy of the model info so that later in the code the model id and maxTokens can be set.
636634
// The bedrockModels array is a constant and updating the model ID from the returned invokedModelID value
637635
// in a prompt router response isn't possible on the constant.
638-
let model = JSON.parse(JSON.stringify(bedrockModels[id]))
639-
640-
// If modelMaxTokens is explicitly set in options, override the default
641-
if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
642-
model.maxTokens = this.options.modelMaxTokens
643-
}
636+
model = { id: id, info: JSON.parse(JSON.stringify(bedrockModels[id])) }
637+
} else if (modelType && modelType.includes("router")) {
638+
model = this.getModelById(bedrockDefaultPromptRouterModelId as string)
639+
} else {
640+
model = this.getModelById(bedrockDefaultModelId as string)
641+
}
644642

645-
return { id, info: model }
643+
// If modelMaxTokens is explicitly set in options, override the default
644+
if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
645+
model.info.maxTokens = this.options.modelMaxTokens
646646
}
647647

648-
return { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] }
648+
return model
649649
}
650650

651651
override getModel(): { id: BedrockModelId | string; info: SharedModelInfo } {
@@ -657,12 +657,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
657657

658658
// If custom ARN is provided, use it
659659
if (this.options.awsCustomArn) {
660-
modelConfig = this.getModelById(this.arnInfo.modelId)
661-
662-
if (!modelConfig)
663-
// An ARN was used, but no model info match found, use default model values for cost calculations and context window
664-
// But continue using the ARN as the identifier in the Bedrock interaction
665-
modelConfig = this.getModelById(bedrockDefaultPromptRouterModelId)
660+
modelConfig = this.getModelById(this.arnInfo.modelId, this.arnInfo.modelType)
666661

667662
//If the user entered an ARN for a foundation-model they've done the same thing as picking from our list of options.
668663
//We leave the model data matching the same as if a drop-down input method was used by not overwriting the model ID with the user input ARN
@@ -864,7 +859,11 @@ Please verify:
864859
messageTemplate: `Request was throttled or rate limited. Please try:
865860
1. Reducing the frequency of requests
866861
2. If using a provisioned model, check its throughput settings
867-
3. Contact AWS support to request a quota increase if needed`,
862+
3. Contact AWS support to request a quota increase if needed
863+
864+
{formattedErrorDetails}
865+
866+
`,
868867
logLevel: "error",
869868
},
870869
TOO_MANY_TOKENS: {

0 commit comments

Comments
 (0)