Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 140 additions & 14 deletions src/api/providers/__tests__/bedrock.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,23 @@ jest.mock("@aws-sdk/credential-providers", () => {
return { fromIni: mockFromIni }
})

// Mock BedrockRuntimeClient and ConverseStreamCommand
//
// These two jest.fn() instances are declared at module scope so individual
// tests can reset them and inspect how they were called.
// NOTE: keep the `mock` prefix on both names. The jest.mock factory below
// closes over them, and Jest's hoisting of jest.mock() only tolerates
// references to out-of-scope variables whose names start with `mock`.

// Stands in for the ConverseStreamCommand constructor; the arguments it
// receives are recorded on mockConverseStreamCommand.mock.calls.
const mockConverseStreamCommand = jest.fn()
// Stands in for BedrockRuntimeClient.send; by default it resolves to a
// response whose stream is an empty array.
const mockSend = jest.fn().mockResolvedValue({
stream: [],
})

// Replace the AWS SDK module with stubs: every BedrockRuntimeClient that gets
// constructed exposes the shared mockSend as its `send` method.
jest.mock("@aws-sdk/client-bedrock-runtime", () => ({
BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
send: mockSend,
})),
ConverseStreamCommand: mockConverseStreamCommand,
// Stubbed with a bare jest.fn(); nothing in the visible tests inspects it.
ConverseCommand: jest.fn(),
}))

import { AwsBedrockHandler } from "../bedrock"
import { MessageContent } from "../../../shared/api"
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime"
import { Anthropic } from "@anthropic-ai/sdk"
const { fromIni } = require("@aws-sdk/credential-providers")
import { logger } from "../../../utils/logging"
Expand Down Expand Up @@ -57,20 +71,18 @@ describe("AwsBedrockHandler", () => {
})

it("should handle inference-profile ARN with apne3 region prefix", () => {
// Mock the parseArn method before creating the handler
const originalParseArn = AwsBedrockHandler.prototype["parseArn"]
const parseArnMock = jest.fn().mockImplementation(function (this: any, arn: string, region?: string) {
return originalParseArn.call(this, arn, region)
})
AwsBedrockHandler.prototype["parseArn"] = parseArnMock

try {
// Create a handler with a custom ARN that includes the apne3. region prefix
const customArnHandler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "ap-northeast-3", // Osaka region
awsRegion: "ap-northeast-3",
awsCustomArn:
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
})
Expand All @@ -79,23 +91,17 @@ describe("AwsBedrockHandler", () => {

expect(modelInfo.id).toBe(
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
),
// Verify the model info is defined
expect(modelInfo.info).toBeDefined()
)
expect(modelInfo.info).toBeDefined()

// Verify parseArn was called with the correct ARN
expect(parseArnMock).toHaveBeenCalledWith(
"arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0",
"ap-northeast-3",
)

// Verify the model ID was correctly extracted from the ARN (without the region prefix)
expect((customArnHandler as any).arnInfo.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")

// Verify cross-region inference flag is false since apne3 is a prefix for a single region
expect((customArnHandler as any).arnInfo.crossRegionInference).toBe(false)
} finally {
// Restore the original method
AwsBedrockHandler.prototype["parseArn"] = originalParseArn
}
})
Expand All @@ -109,12 +115,132 @@ describe("AwsBedrockHandler", () => {
awsRegion: "us-east-1",
})
const modelInfo = customArnHandler.getModel()
// Should fall back to default prompt router model
expect(modelInfo.id).toBe(
"arn:aws:bedrock:ap-northeast-3:123456789012:default-prompt-router/my_router_arn_no_model",
) // bedrockDefaultPromptRouterModelId
)
expect(modelInfo.info).toBeDefined()
expect(modelInfo.info.maxTokens).toBe(4096)
})
})

describe("image handling", () => {
	// Arbitrary bytes, base64-encoded, reused as the payload of every image block.
	const mockImageData = Buffer.from("test-image-data").toString("base64")

	/** Builds a base64 image content block that carries the shared fixture bytes. */
	const base64Image = (mediaType: "image/jpeg" | "image/png") => ({
		type: "image" as const,
		source: {
			type: "base64" as const,
			data: mockImageData,
			media_type: mediaType,
		},
	})

	/** Builds a plain text content block. */
	const textBlock = (text: string) => ({ type: "text" as const, text })

	/** Wraps content blocks into a conversation holding a single user message. */
	const asUserConversation = (
		content: Anthropic.Messages.MessageParam["content"],
	): Anthropic.Messages.MessageParam[] => [{ role: "user", content }]

	/**
	 * Advances the message generator one step, checks that a ConverseStreamCommand
	 * was constructed, and returns the first argument of its first construction.
	 */
	const capturePayload = async (conversation: Anthropic.Messages.MessageParam[]) => {
		const stream = handler.createMessage("", conversation)
		await stream.next() // Start the generator

		expect(mockConverseStreamCommand).toHaveBeenCalled()
		return mockConverseStreamCommand.mock.calls[0][0]
	}

	beforeEach(() => {
		// Start each test from clean mocks, then restore the default empty-stream response
		mockSend.mockReset()
		mockConverseStreamCommand.mockReset()

		mockSend.mockResolvedValue({ stream: [] })
	})

	it("should properly convert image content to Bedrock format", async () => {
		const conversation = asUserConversation([base64Image("image/jpeg"), textBlock("What's in this image?")])

		const payload = await capturePayload(conversation)

		// The leading block of the first message must be a Bedrock image block
		const converted = payload.messages[0].content[0]
		expect(converted).toHaveProperty("image")
		expect(converted.image).toHaveProperty("format", "jpeg")
		expect(converted.image.source).toHaveProperty("bytes")
		expect(converted.image.source.bytes).toBeInstanceOf(Uint8Array)
	})

	it("should reject unsupported image formats", async () => {
		// The cast smuggles a media type past the compiler that the SDK typings do not allow
		const conversation = asUserConversation([base64Image("image/tiff" as "image/jpeg")])

		const stream = handler.createMessage("", conversation)
		await expect(stream.next()).rejects.toThrow("Unsupported image format: tiff")
	})

	it("should handle multiple images in a single message", async () => {
		const conversation = asUserConversation([
			base64Image("image/jpeg"),
			textBlock("First image"),
			base64Image("image/png"),
			textBlock("Second image"),
		])

		const payload = await capturePayload(conversation)

		// Images sit at positions 0 and 2, interleaved with the text blocks
		const blocks = payload.messages[0].content
		const leadingImage = blocks[0]
		const trailingImage = blocks[2]

		expect(leadingImage).toHaveProperty("image")
		expect(leadingImage.image).toHaveProperty("format", "jpeg")
		expect(trailingImage).toHaveProperty("image")
		expect(trailingImage.image).toHaveProperty("format", "png")
	})
})
})
54 changes: 35 additions & 19 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
ConverseStreamCommand,
ConverseCommand,
BedrockRuntimeClientConfig,
ContentBlock,
} from "@aws-sdk/client-bedrock-runtime"
import { fromIni } from "@aws-sdk/credential-providers"
import { Anthropic } from "@anthropic-ai/sdk"
Expand All @@ -23,6 +24,7 @@ import { Message, SystemContentBlock } from "@aws-sdk/client-bedrock-runtime"
import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy"
import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types"
import { AMAZON_BEDROCK_REGION_INFO } from "../../shared/aws_regions"
import { convertToBedrockConverseMessages as sharedConverter } from "../transform/bedrock-converse-format"

const BEDROCK_DEFAULT_TEMPERATURE = 0.3
const BEDROCK_MAX_TOKENS = 4096
Expand Down Expand Up @@ -434,7 +436,18 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
modelInfo?: any,
conversationId?: string, // Optional conversation ID to track cache points across messages
): { system: SystemContentBlock[]; messages: Message[] } {
// Convert model info to expected format
// First convert messages using shared converter for proper image handling
const convertedMessages = sharedConverter(anthropicMessages as Anthropic.Messages.MessageParam[])

// If prompt caching is disabled, return the converted messages directly
if (!usePromptCache) {
return {
system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
messages: convertedMessages,
}
}

// Convert model info to expected format for cache strategy
const cacheModelInfo: CacheModelInfo = {
maxTokens: modelInfo?.maxTokens || 8192,
contextWindow: modelInfo?.contextWindow || 200_000,
Expand All @@ -444,18 +457,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
cachableFields: modelInfo?.cachableFields || [],
}

// Clean messages by removing any existing cache points
const cleanedMessages = anthropicMessages.map((msg) => {
if (typeof msg.content === "string") {
return msg
}
const cleaned = {
...msg,
content: this.removeCachePoints(msg.content),
}
return cleaned
})

// Get previous cache point placements for this conversation if available
const previousPlacements =
conversationId && this.previousCachePointPlacements[conversationId]
Expand All @@ -466,21 +467,36 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
const config = {
modelInfo: cacheModelInfo,
systemPrompt: systemMessage,
messages: cleanedMessages as Anthropic.Messages.MessageParam[],
messages: anthropicMessages as Anthropic.Messages.MessageParam[],
usePromptCache,
previousCachePointPlacements: previousPlacements,
}

// Determine optimal cache points
// Get cache point placements
let strategy = new MultiPointStrategy(config)
const result = strategy.determineOptimalCachePoints()
const cacheResult = strategy.determineOptimalCachePoints()

// Store cache point placements for future use if conversation ID is provided
if (conversationId && result.messageCachePointPlacements) {
this.previousCachePointPlacements[conversationId] = result.messageCachePointPlacements
if (conversationId && cacheResult.messageCachePointPlacements) {
this.previousCachePointPlacements[conversationId] = cacheResult.messageCachePointPlacements
}

return result
// Apply cache points to the properly converted messages
const messagesWithCache = convertedMessages.map((msg, index) => {
const placement = cacheResult.messageCachePointPlacements?.find((p) => p.index === index)
if (placement) {
return {
...msg,
content: [...(msg.content || []), { cachePoint: { type: "default" } } as ContentBlock],
}
}
return msg
})

return {
system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
messages: messagesWithCache,
}
}

/************************************************************************************
Expand Down
Loading