Commit 9b267e9

Add Vertex AI prompt caching support and enhance streaming handling
- Implemented comprehensive prompt caching strategy for Vertex AI models
- Added support for caching system prompts and user message text blocks
- Enhanced stream processing to handle cache-related usage metrics
- Updated model configurations to enable prompt caching
- Improved type definitions for Vertex AI message handling
1 parent 30c1c25 commit 9b267e9
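
The test diff below pins down the request shape this caching strategy produces. As a rough sketch of the idea (the actual vertex.ts change is not part of the diff shown here), the transformation amounts to wrapping the system prompt and user text in blocks marked with an ephemeral cache_control. The helper name withPromptCaching and its local types are illustrative only, not code from this commit.

type CacheControl = { type: "ephemeral" }
type TextBlock = { type: "text"; text: string; cache_control?: CacheControl }
type ChatMessage = { role: "user" | "assistant"; content: string | TextBlock[] }

// Illustrative only: shape a request the way the updated tests expect it, with
// ephemeral cache_control markers on the system prompt and on user text blocks.
function withPromptCaching(systemPrompt: string, messages: ChatMessage[]): { system: TextBlock[]; messages: ChatMessage[] } {
	// The system prompt is sent as a single text block flagged for caching.
	const system: TextBlock[] = [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }]
	// User turns become text blocks with cache_control; assistant turns pass through unchanged.
	const cached: ChatMessage[] = messages.map((message) =>
		message.role === "user" && typeof message.content === "string"
			? { ...message, content: [{ type: "text", text: message.content, cache_control: { type: "ephemeral" } }] }
			: message,
	)
	return { system, messages: cached }
}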

3 files changed (+435, -35 lines)


src/api/providers/__tests__/vertex.test.ts

Lines changed: 210 additions & 5 deletions
@@ -4,6 +4,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 
 import { VertexHandler } from "../vertex"
+import { ApiStreamChunk } from "../../transform/stream"
 
 // Mock Vertex SDK
 jest.mock("@anthropic-ai/vertex-sdk", () => ({
@@ -128,7 +129,7 @@ describe("VertexHandler", () => {
 			;(handler["client"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
-			const chunks = []
+			const chunks: ApiStreamChunk[] = []
 
 			for await (const chunk of stream) {
 				chunks.push(chunk)
@@ -158,8 +159,29 @@ describe("VertexHandler", () => {
 				model: "claude-3-5-sonnet-v2@20241022",
 				max_tokens: 8192,
 				temperature: 0,
-				system: systemPrompt,
-				messages: mockMessages,
+				system: [
+					{
+						type: "text",
+						text: "You are a helpful assistant",
+						cache_control: { type: "ephemeral" },
+					},
+				],
+				messages: [
+					{
+						role: "user",
+						content: [
+							{
+								type: "text",
+								text: "Hello",
+								cache_control: { type: "ephemeral" },
+							},
+						],
+					},
+					{
+						role: "assistant",
+						content: "Hi there!",
+					},
+				],
 				stream: true,
 			})
 		})
@@ -196,7 +218,7 @@ describe("VertexHandler", () => {
 			;(handler["client"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
-			const chunks = []
+			const chunks: ApiStreamChunk[] = []
 
 			for await (const chunk of stream) {
 				chunks.push(chunk)
@@ -230,6 +252,183 @@ describe("VertexHandler", () => {
 				}
 			}).rejects.toThrow("Vertex API error")
 		})
+
+		it("should handle prompt caching for supported models", async () => {
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+							cache_creation_input_tokens: 3,
+							cache_read_input_tokens: 2,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "text",
+						text: "Hello",
+					},
+				},
+				{
+					type: "content_block_delta",
+					delta: {
+						type: "text_delta",
+						text: " world!",
+					},
+				},
+				{
+					type: "message_delta",
+					usage: {
+						output_tokens: 5,
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, [
+				{
+					role: "user",
+					content: "First message",
+				},
+				{
+					role: "assistant",
+					content: "Response",
+				},
+				{
+					role: "user",
+					content: "Second message",
+				},
+			])
+
+			const chunks: ApiStreamChunk[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify usage information
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(2)
+			expect(usageChunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 0,
+				cacheWriteTokens: 3,
+				cacheReadTokens: 2,
+			})
+			expect(usageChunks[1]).toEqual({
+				type: "usage",
+				inputTokens: 0,
+				outputTokens: 5,
+			})
+
+			// Verify text content
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2)
+			expect(textChunks[0].text).toBe("Hello")
+			expect(textChunks[1].text).toBe(" world!")
+
+			// Verify cache control was added correctly
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					system: [
+						{
+							type: "text",
+							text: "You are a helpful assistant",
+							cache_control: { type: "ephemeral" },
+						},
+					],
+					messages: [
+						expect.objectContaining({
+							role: "user",
+							content: [
+								{
+									type: "text",
+									text: "First message",
+									cache_control: { type: "ephemeral" },
+								},
+							],
+						}),
+						expect.objectContaining({
+							role: "assistant",
+							content: "Response",
+						}),
+						expect.objectContaining({
+							role: "user",
+							content: [
+								{
+									type: "text",
+									text: "Second message",
+									cache_control: { type: "ephemeral" },
+								},
+							],
+						}),
+					],
+				}),
+			)
+		})
+
+		it("should handle cache-related usage metrics", async () => {
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+							cache_creation_input_tokens: 5,
+							cache_read_input_tokens: 3,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "text",
+						text: "Hello",
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check for cache-related metrics in usage chunk
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0]).toHaveProperty("cacheWriteTokens", 5)
+			expect(usageChunks[0]).toHaveProperty("cacheReadTokens", 3)
+		})
 	})
 
 	describe("completePrompt", () => {
@@ -240,7 +439,13 @@ describe("VertexHandler", () => {
 				model: "claude-3-5-sonnet-v2@20241022",
 				max_tokens: 8192,
 				temperature: 0,
-				messages: [{ role: "user", content: "Test prompt" }],
+				system: "",
+				messages: [
+					{
+						role: "user",
+						content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
+					},
+				],
 				stream: false,
 			})
 		})
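
Taken together, the new assertions describe the ApiStreamChunk stream a caller sees: "text" chunks carrying the streamed completion, plus "usage" chunks that may include cacheWriteTokens and cacheReadTokens alongside inputTokens and outputTokens. A hedged consumption sketch, assuming a handler constructed elsewhere (the constructor options are not shown in this diff) and using the same relative import path as the test file:

import { VertexHandler } from "../vertex"

// Illustrative only: tally the usage chunks the tests above assert on.
async function collectUsage(handler: VertexHandler) {
	const totals = { input: 0, output: 0, cacheWrites: 0, cacheReads: 0 }
	const stream = handler.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])
	for await (const chunk of stream) {
		if (chunk.type === "usage") {
			// cacheWriteTokens / cacheReadTokens are only present when the model reports cache activity.
			totals.input += chunk.inputTokens
			totals.output += chunk.outputTokens
			totals.cacheWrites += chunk.cacheWriteTokens ?? 0
			totals.cacheReads += chunk.cacheReadTokens ?? 0
		}
	}
	return totals
}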
