Skip to content

Commit b077267

Browse files
fixes image support in bedrock. regression from prompt cache implementation (#2723)
fixes image support in bedrock. regression created during prompt caching implementation
1 parent d86d601 commit b077267

File tree

2 files changed

+175
-33
lines changed

2 files changed

+175
-33
lines changed

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 140 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,23 @@ jest.mock("@aws-sdk/credential-providers", () => {
77
return { fromIni: mockFromIni }
88
})
99

10+
// Mock BedrockRuntimeClient and ConverseStreamCommand
11+
const mockConverseStreamCommand = jest.fn()
12+
const mockSend = jest.fn().mockResolvedValue({
13+
stream: [],
14+
})
15+
16+
jest.mock("@aws-sdk/client-bedrock-runtime", () => ({
17+
BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
18+
send: mockSend,
19+
})),
20+
ConverseStreamCommand: mockConverseStreamCommand,
21+
ConverseCommand: jest.fn(),
22+
}))
23+
1024
import { AwsBedrockHandler } from "../bedrock"
1125
import { MessageContent } from "../../../shared/api"
12-
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
26+
import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime"
1327
import { Anthropic } from "@anthropic-ai/sdk"
1428
const { fromIni } = require("@aws-sdk/credential-providers")
1529
import { logger } from "../../../utils/logging"
@@ -57,20 +71,18 @@ describe("AwsBedrockHandler", () => {
5771
})
5872

5973
it("should handle inference-profile ARN with apne3 region prefix", () => {
60-
// Mock the parseArn method before creating the handler
6174
const originalParseArn = AwsBedrockHandler.prototype["parseArn"]
6275
const parseArnMock = jest.fn().mockImplementation(function (this: any, arn: string, region?: string) {
6376
return originalParseArn.call(this, arn, region)
6477
})
6578
AwsBedrockHandler.prototype["parseArn"] = parseArnMock
6679

6780
try {
68-
// Create a handler with a custom ARN that includes the apne3. region prefix
6981
const customArnHandler = new AwsBedrockHandler({
7082
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
7183
awsAccessKey: "test-access-key",
7284
awsSecretKey: "test-secret-key",
73-
awsRegion: "ap-northeast-3", // Osaka region
85+
awsRegion: "ap-northeast-3",
7486
awsCustomArn:
7587
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
7688
})
@@ -79,23 +91,17 @@ describe("AwsBedrockHandler", () => {
7991

8092
expect(modelInfo.id).toBe(
8193
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
82-
),
83-
// Verify the model info is defined
84-
expect(modelInfo.info).toBeDefined()
94+
)
95+
expect(modelInfo.info).toBeDefined()
8596

86-
// Verify parseArn was called with the correct ARN
8797
expect(parseArnMock).toHaveBeenCalledWith(
8898
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
8999
"ap-northeast-3",
90100
)
91101

92-
// Verify the model ID was correctly extracted from the ARN (without the region prefix)
93102
expect((customArnHandler as any).arnInfo.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
94-
95-
// Verify cross-region inference flag is false since apne3 is a prefix for a single region
96103
expect((customArnHandler as any).arnInfo.crossRegionInference).toBe(false)
97104
} finally {
98-
// Restore the original method
99105
AwsBedrockHandler.prototype["parseArn"] = originalParseArn
100106
}
101107
})
@@ -109,12 +115,132 @@ describe("AwsBedrockHandler", () => {
109115
awsRegion: "us-east-1",
110116
})
111117
const modelInfo = customArnHandler.getModel()
112-
// Should fall back to default prompt router model
113118
expect(modelInfo.id).toBe(
114119
"arn:aws:bedrock:ap-northeast-3:123456789012:default-prompt-router/my_router_arn_no_model",
115-
) // bedrockDefaultPromptRouterModelId
120+
)
116121
expect(modelInfo.info).toBeDefined()
117122
expect(modelInfo.info.maxTokens).toBe(4096)
118123
})
119124
})
125+
126+
describe("image handling", () => {
127+
const mockImageData = Buffer.from("test-image-data").toString("base64")
128+
129+
beforeEach(() => {
130+
// Reset the mocks before each test
131+
mockSend.mockReset()
132+
mockConverseStreamCommand.mockReset()
133+
134+
mockSend.mockResolvedValue({
135+
stream: [],
136+
})
137+
})
138+
139+
it("should properly convert image content to Bedrock format", async () => {
140+
const messages: Anthropic.Messages.MessageParam[] = [
141+
{
142+
role: "user",
143+
content: [
144+
{
145+
type: "image",
146+
source: {
147+
type: "base64",
148+
data: mockImageData,
149+
media_type: "image/jpeg",
150+
},
151+
},
152+
{
153+
type: "text",
154+
text: "What's in this image?",
155+
},
156+
],
157+
},
158+
]
159+
160+
const generator = handler.createMessage("", messages)
161+
await generator.next() // Start the generator
162+
163+
// Verify the command was created with the right payload
164+
expect(mockConverseStreamCommand).toHaveBeenCalled()
165+
const commandArg = mockConverseStreamCommand.mock.calls[0][0]
166+
167+
// Verify the image was properly formatted
168+
const imageBlock = commandArg.messages[0].content[0]
169+
expect(imageBlock).toHaveProperty("image")
170+
expect(imageBlock.image).toHaveProperty("format", "jpeg")
171+
expect(imageBlock.image.source).toHaveProperty("bytes")
172+
expect(imageBlock.image.source.bytes).toBeInstanceOf(Uint8Array)
173+
})
174+
175+
it("should reject unsupported image formats", async () => {
176+
const messages: Anthropic.Messages.MessageParam[] = [
177+
{
178+
role: "user",
179+
content: [
180+
{
181+
type: "image",
182+
source: {
183+
type: "base64",
184+
data: mockImageData,
185+
media_type: "image/tiff" as "image/jpeg", // Type assertion to bypass TS
186+
},
187+
},
188+
],
189+
},
190+
]
191+
192+
const generator = handler.createMessage("", messages)
193+
await expect(generator.next()).rejects.toThrow("Unsupported image format: tiff")
194+
})
195+
196+
it("should handle multiple images in a single message", async () => {
197+
const messages: Anthropic.Messages.MessageParam[] = [
198+
{
199+
role: "user",
200+
content: [
201+
{
202+
type: "image",
203+
source: {
204+
type: "base64",
205+
data: mockImageData,
206+
media_type: "image/jpeg",
207+
},
208+
},
209+
{
210+
type: "text",
211+
text: "First image",
212+
},
213+
{
214+
type: "image",
215+
source: {
216+
type: "base64",
217+
data: mockImageData,
218+
media_type: "image/png",
219+
},
220+
},
221+
{
222+
type: "text",
223+
text: "Second image",
224+
},
225+
],
226+
},
227+
]
228+
229+
const generator = handler.createMessage("", messages)
230+
await generator.next() // Start the generator
231+
232+
// Verify the command was created with the right payload
233+
expect(mockConverseStreamCommand).toHaveBeenCalled()
234+
const commandArg = mockConverseStreamCommand.mock.calls[0][0]
235+
236+
// Verify both images were properly formatted
237+
const firstImage = commandArg.messages[0].content[0]
238+
const secondImage = commandArg.messages[0].content[2]
239+
240+
expect(firstImage).toHaveProperty("image")
241+
expect(firstImage.image).toHaveProperty("format", "jpeg")
242+
expect(secondImage).toHaveProperty("image")
243+
expect(secondImage.image).toHaveProperty("format", "png")
244+
})
245+
})
120246
})

src/api/providers/bedrock.ts

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import {
33
ConverseStreamCommand,
44
ConverseCommand,
55
BedrockRuntimeClientConfig,
6+
ContentBlock,
67
} from "@aws-sdk/client-bedrock-runtime"
78
import { fromIni } from "@aws-sdk/credential-providers"
89
import { Anthropic } from "@anthropic-ai/sdk"
@@ -23,6 +24,7 @@ import { Message, SystemContentBlock } from "@aws-sdk/client-bedrock-runtime"
2324
import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy"
2425
import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types"
2526
import { AMAZON_BEDROCK_REGION_INFO } from "../../shared/aws_regions"
27+
import { convertToBedrockConverseMessages as sharedConverter } from "../transform/bedrock-converse-format"
2628

2729
const BEDROCK_DEFAULT_TEMPERATURE = 0.3
2830
const BEDROCK_MAX_TOKENS = 4096
@@ -434,7 +436,18 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
434436
modelInfo?: any,
435437
conversationId?: string, // Optional conversation ID to track cache points across messages
436438
): { system: SystemContentBlock[]; messages: Message[] } {
437-
// Convert model info to expected format
439+
// First convert messages using shared converter for proper image handling
440+
const convertedMessages = sharedConverter(anthropicMessages as Anthropic.Messages.MessageParam[])
441+
442+
// If prompt caching is disabled, return the converted messages directly
443+
if (!usePromptCache) {
444+
return {
445+
system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
446+
messages: convertedMessages,
447+
}
448+
}
449+
450+
// Convert model info to expected format for cache strategy
438451
const cacheModelInfo: CacheModelInfo = {
439452
maxTokens: modelInfo?.maxTokens || 8192,
440453
contextWindow: modelInfo?.contextWindow || 200_000,
@@ -444,18 +457,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
444457
cachableFields: modelInfo?.cachableFields || [],
445458
}
446459

447-
// Clean messages by removing any existing cache points
448-
const cleanedMessages = anthropicMessages.map((msg) => {
449-
if (typeof msg.content === "string") {
450-
return msg
451-
}
452-
const cleaned = {
453-
...msg,
454-
content: this.removeCachePoints(msg.content),
455-
}
456-
return cleaned
457-
})
458-
459460
// Get previous cache point placements for this conversation if available
460461
const previousPlacements =
461462
conversationId && this.previousCachePointPlacements[conversationId]
@@ -466,21 +467,36 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
466467
const config = {
467468
modelInfo: cacheModelInfo,
468469
systemPrompt: systemMessage,
469-
messages: cleanedMessages as Anthropic.Messages.MessageParam[],
470+
messages: anthropicMessages as Anthropic.Messages.MessageParam[],
470471
usePromptCache,
471472
previousCachePointPlacements: previousPlacements,
472473
}
473474

474-
// Determine optimal cache points
475+
// Get cache point placements
475476
let strategy = new MultiPointStrategy(config)
476-
const result = strategy.determineOptimalCachePoints()
477+
const cacheResult = strategy.determineOptimalCachePoints()
477478

478479
// Store cache point placements for future use if conversation ID is provided
479-
if (conversationId && result.messageCachePointPlacements) {
480-
this.previousCachePointPlacements[conversationId] = result.messageCachePointPlacements
480+
if (conversationId && cacheResult.messageCachePointPlacements) {
481+
this.previousCachePointPlacements[conversationId] = cacheResult.messageCachePointPlacements
481482
}
482483

483-
return result
484+
// Apply cache points to the properly converted messages
485+
const messagesWithCache = convertedMessages.map((msg, index) => {
486+
const placement = cacheResult.messageCachePointPlacements?.find((p) => p.index === index)
487+
if (placement) {
488+
return {
489+
...msg,
490+
content: [...(msg.content || []), { cachePoint: { type: "default" } } as ContentBlock],
491+
}
492+
}
493+
return msg
494+
})
495+
496+
return {
497+
system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
498+
messages: messagesWithCache,
499+
}
484500
}
485501

486502
/************************************************************************************

0 commit comments

Comments
 (0)