diff --git a/src/api/providers/__tests__/vertex.test.ts b/src/api/providers/__tests__/vertex.test.ts
index ebe60ba0c68..6e81fd771b7 100644
--- a/src/api/providers/__tests__/vertex.test.ts
+++ b/src/api/providers/__tests__/vertex.test.ts
@@ -4,6 +4,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 
 import { VertexHandler } from "../vertex"
+import { ApiStreamChunk } from "../../transform/stream"
 
 // Mock Vertex SDK
 jest.mock("@anthropic-ai/vertex-sdk", () => ({
@@ -128,7 +129,7 @@ describe("VertexHandler", () => {
 			;(handler["client"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
-			const chunks = []
+			const chunks: ApiStreamChunk[] = []
 
 			for await (const chunk of stream) {
 				chunks.push(chunk)
@@ -158,8 +159,29 @@ describe("VertexHandler", () => {
 				model: "claude-3-5-sonnet-v2@20241022",
 				max_tokens: 8192,
 				temperature: 0,
-				system: systemPrompt,
-				messages: mockMessages,
+				system: [
+					{
+						type: "text",
+						text: "You are a helpful assistant",
+						cache_control: { type: "ephemeral" },
+					},
+				],
+				messages: [
+					{
+						role: "user",
+						content: [
+							{
+								type: "text",
+								text: "Hello",
+								cache_control: { type: "ephemeral" },
+							},
+						],
+					},
+					{
+						role: "assistant",
+						content: "Hi there!",
+					},
+				],
 				stream: true,
 			})
 		})
@@ -196,7 +218,7 @@ describe("VertexHandler", () => {
 			;(handler["client"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
-			const chunks = []
+			const chunks: ApiStreamChunk[] = []
 
 			for await (const chunk of stream) {
 				chunks.push(chunk)
@@ -230,6 +252,183 @@ describe("VertexHandler", () => {
 				}
 			}).rejects.toThrow("Vertex API error")
 		})
+
+		it("should handle prompt caching for supported models", async () => {
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+							cache_creation_input_tokens: 3,
+							cache_read_input_tokens: 2,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "text",
+						text: "Hello",
+					},
+				},
+				{
+					type: "content_block_delta",
+					delta: {
+						type: "text_delta",
+						text: " world!",
+					},
+				},
+				{
+					type: "message_delta",
+					usage: {
+						output_tokens: 5,
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, [
+				{
+					role: "user",
+					content: "First message",
+				},
+				{
+					role: "assistant",
+					content: "Response",
+				},
+				{
+					role: "user",
+					content: "Second message",
+				},
+			])
+
+			const chunks: ApiStreamChunk[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify usage information
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(2)
+			expect(usageChunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 0,
+				cacheWriteTokens: 3,
+				cacheReadTokens: 2,
+			})
+			expect(usageChunks[1]).toEqual({
+				type: "usage",
+				inputTokens: 0,
+				outputTokens: 5,
+			})
+
+			// Verify text content
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2)
+			expect(textChunks[0].text).toBe("Hello")
+			expect(textChunks[1].text).toBe(" world!")
+
+			// Verify cache control was added correctly
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					system: [
+						{
+							type: "text",
+							text: "You are a helpful assistant",
+							cache_control: { type: "ephemeral" },
+						},
+					],
+					messages: [
+						expect.objectContaining({
+							role: "user",
+							content: [
+								{
+									type: "text",
+									text: "First message",
+									cache_control: { type: "ephemeral" },
+								},
+							],
+						}),
+						expect.objectContaining({
+							role: "assistant",
+							content: "Response",
+						}),
+						expect.objectContaining({
+							role: "user",
+							content: [
+								{
+									type: "text",
+									text: "Second message",
+									cache_control: { type: "ephemeral" },
+								},
+							],
+						}),
+					],
+				}),
+			)
+		})
+
+		it("should handle cache-related usage metrics", async () => {
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+							cache_creation_input_tokens: 5,
+							cache_read_input_tokens: 3,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "text",
+						text: "Hello",
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check for cache-related metrics in usage chunk
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0]).toHaveProperty("cacheWriteTokens", 5)
+			expect(usageChunks[0]).toHaveProperty("cacheReadTokens", 3)
+		})
 	})
 
 	describe("completePrompt", () => {
@@ -240,7 +439,13 @@ describe("VertexHandler", () => {
 				model: "claude-3-5-sonnet-v2@20241022",
 				max_tokens: 8192,
 				temperature: 0,
-				messages: [{ role: "user", content: "Test prompt" }],
+				system: "",
+				messages: [
+					{
+						role: "user",
+						content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
+					},
+				],
 				stream: false,
 			})
 		})
diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts
index 0ee22e5893d..70562766c3b 100644
--- a/src/api/providers/vertex.ts
+++ b/src/api/providers/vertex.ts
@@ -1,9 +1,86 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
+import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
 
 import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 
+// Types for Vertex SDK
+
+/**
+ * Vertex API has specific limitations for prompt caching:
+ * 1. Maximum of 4 blocks can have cache_control
+ * 2. Only text blocks can be cached (images and other content types cannot)
+ * 3. Cache control can only be applied to user messages, not assistant messages
+ *
+ * Our caching strategy:
+ * - Cache the system prompt (1 block)
+ * - Cache the last text block of the second-to-last user message (1 block)
+ * - Cache the last text block of the last user message (1 block)
+ * This ensures we stay under the 4-block limit while maintaining effective caching
+ * for the most relevant context.
+ */
+
+interface VertexTextBlock {
+	type: "text"
+	text: string
+	cache_control?: { type: "ephemeral" }
+}
+
+interface VertexImageBlock {
+	type: "image"
+	source: {
+		type: "base64"
+		media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp"
+		data: string
+	}
+}
+
+type VertexContentBlock = VertexTextBlock | VertexImageBlock
+
+interface VertexUsage {
+	input_tokens?: number
+	output_tokens?: number
+	cache_creation_input_tokens?: number
+	cache_read_input_tokens?: number
+}
+
+interface VertexMessage extends Omit<Anthropic.Messages.MessageParam, "content"> {
+	content: string | VertexContentBlock[]
+}
+
+interface VertexMessageCreateParams {
+	model: string
+	max_tokens: number
+	temperature: number
+	system: string | VertexTextBlock[]
+	messages: VertexMessage[]
+	stream: boolean
+}
+
+interface VertexMessageResponse {
+	content: Array<{ type: "text"; text: string }>
+}
+
+interface VertexMessageStreamEvent {
+	type: "message_start" | "message_delta" | "content_block_start" | "content_block_delta"
+	message?: {
+		usage: VertexUsage
+	}
+	usage?: {
+		output_tokens: number
+	}
+	content_block?: {
+		type: "text"
+		text: string
+	}
+	index?: number
+	delta?: {
+		type: "text_delta"
+		text: string
+	}
+}
+
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
 export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -18,37 +95,120 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		})
 	}
 
+	private formatMessageForCache(message: Anthropic.Messages.MessageParam, shouldCache: boolean): VertexMessage {
+		// Assistant messages are kept as-is since they can't be cached
+		if (message.role === "assistant") {
+			return message as VertexMessage
+		}
+
+		// For string content, we convert to array format with optional cache control
+		if (typeof message.content === "string") {
+			return {
+				...message,
+				content: [
+					{
+						type: "text" as const,
+						text: message.content,
+						// For string content, we only have one block so it's always the last
+						...(shouldCache && { cache_control: { type: "ephemeral" } }),
+					},
+				],
+			}
+		}
+
+		// For array content, find the last text block index once before mapping
+		const lastTextBlockIndex = message.content.reduce(
+			(lastIndex, content, index) => (content.type === "text" ? index : lastIndex),
+			-1,
+		)
+
+		// Then use this pre-calculated index in the map function
+		return {
+			...message,
+			content: message.content.map((content, contentIndex) => {
+				// Images and other non-text content are passed through unchanged
+				if (content.type === "image") {
+					return content as VertexImageBlock
+				}
+
+				// Check if this is the last text block using our pre-calculated index
+				const isLastTextBlock = contentIndex === lastTextBlockIndex
+
+				return {
+					type: "text" as const,
+					text: (content as { text: string }).text,
+					...(shouldCache && isLastTextBlock && { cache_control: { type: "ephemeral" } }),
+				}
+			}),
+		}
+	}
+
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const stream = await this.client.messages.create({
-			model: this.getModel().id,
-			max_tokens: this.getModel().info.maxTokens || 8192,
+		const model = this.getModel()
+		const useCache = model.info.supportsPromptCache
+
+		// Find indices of user messages that we want to cache
+		// We only cache the last two user messages to stay within the 4-block limit
+		// (1 block for system + 1 block each for last two user messages = 3 total)
+		const userMsgIndices = useCache
+			? messages.reduce((acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc), [] as number[])
+			: []
+		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+		const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+		// Create the stream with appropriate caching configuration
+		const params = {
+			model: model.id,
+			max_tokens: model.info.maxTokens || 8192,
 			temperature: this.options.modelTemperature ?? 0,
-			system: systemPrompt,
-			messages,
+			// Cache the system prompt if caching is enabled
+			system: useCache
+				? [
+						{
+							text: systemPrompt,
+							type: "text" as const,
+							cache_control: { type: "ephemeral" },
+						},
+					]
+				: systemPrompt,
+			messages: messages.map((message, index) => {
+				// Only cache the last two user messages
+				const shouldCache = useCache && (index === lastUserMsgIndex || index === secondLastMsgUserIndex)
+				return this.formatMessageForCache(message, shouldCache)
+			}),
 			stream: true,
-		})
+		}
+
+		const stream = (await this.client.messages.create(
+			params as Anthropic.Messages.MessageCreateParamsStreaming,
+		)) as unknown as AnthropicStream<VertexMessageStreamEvent>
+
+		// Process the stream chunks
 		for await (const chunk of stream) {
 			switch (chunk.type) {
-				case "message_start":
-					const usage = chunk.message.usage
+				case "message_start": {
+					const usage = chunk.message!.usage
 					yield {
 						type: "usage",
 						inputTokens: usage.input_tokens || 0,
 						outputTokens: usage.output_tokens || 0,
+						cacheWriteTokens: usage.cache_creation_input_tokens,
+						cacheReadTokens: usage.cache_read_input_tokens,
 					}
 					break
-				case "message_delta":
+				}
+				case "message_delta": {
 					yield {
 						type: "usage",
 						inputTokens: 0,
-						outputTokens: chunk.usage.output_tokens || 0,
+						outputTokens: chunk.usage!.output_tokens || 0,
 					}
 					break
-
-				case "content_block_start":
-					switch (chunk.content_block.type) {
-						case "text":
-							if (chunk.index > 0) {
+				}
+				case "content_block_start": {
+					switch (chunk.content_block!.type) {
+						case "text": {
+							if (chunk.index! > 0) {
 								yield {
 									type: "text",
 									text: "\n",
@@ -56,21 +216,25 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 							}
 							yield {
 								type: "text",
-								text: chunk.content_block.text,
+								text: chunk.content_block!.text,
 							}
 							break
+						}
 					}
 					break
-				case "content_block_delta":
-					switch (chunk.delta.type) {
-						case "text_delta":
+				}
+				case "content_block_delta": {
+					switch (chunk.delta!.type) {
+						case "text_delta": {
 							yield {
 								type: "text",
-								text: chunk.delta.text,
+								text: chunk.delta!.text,
 							}
 							break
+						}
 					}
 					break
+				}
 			}
 		}
 	}
@@ -86,13 +250,34 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const response = await this.client.messages.create({
-				model: this.getModel().id,
-				max_tokens: this.getModel().info.maxTokens || 8192,
+			const model = this.getModel()
+			const useCache = model.info.supportsPromptCache
+
+			const params = {
+				model: model.id,
+				max_tokens: model.info.maxTokens || 8192,
 				temperature: this.options.modelTemperature ?? 0,
-				messages: [{ role: "user", content: prompt }],
+				system: "", // No system prompt needed for single completions
+				messages: [
+					{
+						role: "user",
+						content: useCache
+							? [
+									{
+										type: "text" as const,
+										text: prompt,
+										cache_control: { type: "ephemeral" },
+									},
+								]
+							: prompt,
+					},
+				],
 				stream: false,
-			})
+			}
+
+			const response = (await this.client.messages.create(
+				params as Anthropic.Messages.MessageCreateParamsNonStreaming,
+			)) as unknown as VertexMessageResponse
 
 			const content = response.content[0]
 			if (content.type === "text") {
diff --git a/src/shared/api.ts b/src/shared/api.ts
index e7e4c54db6a..cd6aead1a59 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -441,7 +441,7 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsComputerUse: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 	},
@@ -450,41 +450,51 @@ export const vertexModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsComputerUse: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
+		cacheWritesPrice: 3.75,
+		cacheReadsPrice: 0.3,
 	},
 	"claude-3-5-sonnet@20240620": {
 		maxTokens: 8192,
 		contextWindow: 200_000,
 		supportsImages: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
+		cacheWritesPrice: 3.75,
+		cacheReadsPrice: 0.3,
 	},
 	"claude-3-5-haiku@20241022": {
 		maxTokens: 8192,
 		contextWindow: 200_000,
 		supportsImages: false,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 1.0,
 		outputPrice: 5.0,
+		cacheWritesPrice: 1.25,
+		cacheReadsPrice: 0.1,
 	},
 	"claude-3-opus@20240229": {
 		maxTokens: 4096,
 		contextWindow: 200_000,
 		supportsImages: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 15.0,
 		outputPrice: 75.0,
+		cacheWritesPrice: 18.75,
+		cacheReadsPrice: 1.5,
 	},
 	"claude-3-haiku@20240307": {
 		maxTokens: 4096,
 		contextWindow: 200_000,
 		supportsImages: true,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 0.25,
 		outputPrice: 1.25,
+		cacheWritesPrice: 0.3,
+		cacheReadsPrice: 0.03,
 	},
 } as const satisfies Record<string, ModelInfo>
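
A minimal, hypothetical sketch of how a caller might combine the cache-aware usage chunks yielded by createMessage with the cacheWritesPrice / cacheReadsPrice fields added to vertexModels above. The estimateVertexCost helper, the import paths, and the assumption that all prices are USD per million tokens are illustrative only; the diff itself defines just the usage fields on ApiStreamChunk, the handler changes, and the model pricing entries.

import { VertexHandler } from "../src/api/providers/vertex" // path is illustrative
import { vertexModels } from "../src/shared/api" // path is illustrative

// Hypothetical helper: stream one exchange through the handler, collect the
// usage chunks (which now carry cacheWriteTokens/cacheReadTokens), and price
// them with the per-model rates, assuming prices are USD per million tokens.
async function estimateVertexCost(handler: VertexHandler, systemPrompt: string): Promise<number> {
	const info = vertexModels["claude-3-5-sonnet-v2@20241022"]
	let cost = 0

	const stream = handler.createMessage(systemPrompt, [{ role: "user", content: "Hello" }])
	for await (const chunk of stream) {
		if (chunk.type === "usage") {
			cost +=
				(chunk.inputTokens / 1_000_000) * info.inputPrice +
				(chunk.outputTokens / 1_000_000) * info.outputPrice +
				((chunk.cacheWriteTokens ?? 0) / 1_000_000) * info.cacheWritesPrice +
				((chunk.cacheReadTokens ?? 0) / 1_000_000) * info.cacheReadsPrice
		}
	}

	return cost
}

With the rates added above, cache writes are billed at a premium over plain input (3.75 vs 3.0 per million tokens for Sonnet) while cache reads are an order of magnitude cheaper (0.3 vs 3.0), which is the economic rationale for caching the system prompt and the last two user messages.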