
Commit 06d8dd2

Gemini caching fixes (#3096)
1 parent a2dcc18 commit 06d8dd2


14 files changed: +784 -171 lines changed


src/api/providers/anthropic-vertex.ts

Lines changed: 5 additions & 19 deletions
@@ -3,13 +3,14 @@ import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 import { GoogleAuth, JWTInput } from "google-auth-library"
 
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
-import { ApiStream } from "../transform/stream"
 import { safeJsonParse } from "../../shared/safeJsonParse"
 
+import { ApiStream } from "../transform/stream"
+import { addCacheBreakpoints } from "../transform/caching/vertex"
+
 import { getModelParams, SingleCompletionHandler } from "../index"
-import { BaseProvider } from "./base-provider"
 import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
-import { formatMessageForCache } from "../transform/vertex-caching"
+import { BaseProvider } from "./base-provider"
 
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
 export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler {
@@ -57,16 +58,6 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler {
 			thinking,
 		} = this.getModel()
 
-		// Find indices of user messages that we want to cache
-		// We only cache the last two user messages to stay within the 4-block limit
-		// (1 block for system + 1 block each for last two user messages = 3 total)
-		const userMsgIndices = supportsPromptCache
-			? messages.reduce((acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc), [] as number[])
-			: []
-
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-
 		/**
 		 * Vertex API has specific limitations for prompt caching:
 		 * 1. Maximum of 4 blocks can have cache_control
@@ -89,12 +80,7 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler {
 			system: supportsPromptCache
 				? [{ text: systemPrompt, type: "text" as const, cache_control: { type: "ephemeral" } }]
 				: systemPrompt,
-			messages: messages.map((message, index) => {
-				// Only cache the last two user messages.
-				const shouldCache =
-					supportsPromptCache && (index === lastUserMsgIndex || index === secondLastMsgUserIndex)
-				return formatMessageForCache(message, shouldCache)
-			}),
+			messages: supportsPromptCache ? addCacheBreakpoints(messages) : messages,
 			stream: true,
 		}
 
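
The inline caching logic removed above now lives behind addCacheBreakpoints in src/api/transform/caching/vertex.ts, which is not included in this excerpt. As a rough sketch only, assuming the new helper simply mirrors the logic deleted from the handler (mark the last two user messages with an ephemeral cache_control breakpoint, while the system prompt keeps its own breakpoint in the handler), it could look like this:

import { Anthropic } from "@anthropic-ai/sdk"

// Hypothetical approximation; the real src/api/transform/caching/vertex.ts may differ.
export function addCacheBreakpoints(messages: Anthropic.Messages.MessageParam[]): Anthropic.Messages.MessageParam[] {
	// Only the last two user messages get breakpoints, which keeps the request
	// within the 4-block cache_control limit (the system prompt uses one block).
	const userMsgIndices = messages.reduce((acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc), [] as number[])

	const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
	const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1

	return messages.map((message, index) => {
		if (index !== lastUserMsgIndex && index !== secondLastUserMsgIndex) {
			return message
		}

		// Normalize string content to a single text block, then attach the
		// breakpoint to the last text block of the message.
		const content =
			typeof message.content === "string"
				? [{ type: "text" as const, text: message.content }]
				: [...message.content]

		for (let i = content.length - 1; i >= 0; i--) {
			const block = content[i]

			if (block.type === "text") {
				content[i] = { ...block, cache_control: { type: "ephemeral" as const } }
				break
			}
		}

		return { ...message, content }
	})
}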

src/api/providers/glama.ts

Lines changed: 4 additions & 2 deletions
@@ -3,9 +3,11 @@ import axios from "axios"
 import OpenAI from "openai"
 
 import { ApiHandlerOptions, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
+
 import { ApiStream } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { addCacheControlDirectives } from "../transform/caching"
+import { addCacheBreakpoints } from "../transform/caching/anthropic"
+
 import { SingleCompletionHandler } from "../index"
 import { RouterProvider } from "./router-provider"
 
@@ -37,7 +39,7 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHandler {
 		]
 
 		if (modelId.startsWith("anthropic/claude-3")) {
-			addCacheControlDirectives(systemPrompt, openAiMessages)
+			addCacheBreakpoints(systemPrompt, openAiMessages)
 		}
 
 		// Required by Anthropic; other providers default to max tokens allowed.

src/api/providers/openrouter.ts

Lines changed: 7 additions & 35 deletions
@@ -11,9 +11,12 @@ import {
 	OPTIONAL_PROMPT_CACHING_MODELS,
 	REASONING_MODELS,
 } from "../../shared/api"
+
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStreamChunk } from "../transform/stream"
 import { convertToR1Format } from "../transform/r1-format"
+import { addCacheBreakpoints as addAnthropicCacheBreakpoints } from "../transform/caching/anthropic"
+import { addCacheBreakpoints as addGeminiCacheBreakpoints } from "../transform/caching/gemini"
 
 import { getModelParams, SingleCompletionHandler } from "../index"
 import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
@@ -93,42 +96,11 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionHandler {
 
 		const isCacheAvailable = promptCache.supported && (!promptCache.optional || this.options.promptCachingEnabled)
 
-		// Prompt caching: https://openrouter.ai/docs/prompt-caching
-		// Now with Gemini support: https://openrouter.ai/docs/features/prompt-caching
-		// Note that we don't check the `ModelInfo` object because it is cached
-		// in the settings for OpenRouter and the value could be stale.
+		// https://openrouter.ai/docs/features/prompt-caching
 		if (isCacheAvailable) {
-			openAiMessages[0] = {
-				role: "system",
-				// @ts-ignore-next-line
-				content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }],
-			}
-
-			// Add cache_control to the last two user messages
-			// (note: this works because we only ever add one user message at a time, but if we added multiple we'd need to mark the user message before the last assistant message)
-			const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
-
-			lastTwoUserMessages.forEach((msg) => {
-				if (typeof msg.content === "string") {
-					msg.content = [{ type: "text", text: msg.content }]
-				}
-
-				if (Array.isArray(msg.content)) {
-					// NOTE: This is fine since env details will always be added
-					// at the end. But if it wasn't there, and the user added a
-					// image_url type message, it would pop a text part before
-					// it and then move it after to the end.
-					let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
-
-					if (!lastTextPart) {
-						lastTextPart = { type: "text", text: "..." }
-						msg.content.push(lastTextPart)
-					}
-
-					// @ts-ignore-next-line
-					lastTextPart["cache_control"] = { type: "ephemeral" }
-				}
-			})
+			modelId.startsWith("google")
+				? addGeminiCacheBreakpoints(systemPrompt, openAiMessages)
+				: addAnthropicCacheBreakpoints(systemPrompt, openAiMessages)
 		}
 
 		// https://openrouter.ai/docs/transforms

src/api/providers/unbound.ts

Lines changed: 4 additions & 2 deletions
@@ -2,9 +2,11 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
 import { ApiHandlerOptions, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api"
+
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { addCacheControlDirectives } from "../transform/caching"
+import { addCacheBreakpoints } from "../transform/caching/anthropic"
+
 import { SingleCompletionHandler } from "../index"
 import { RouterProvider } from "./router-provider"
 
@@ -39,7 +41,7 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHandler {
 		]
 
 		if (modelId.startsWith("anthropic/claude-3")) {
-			addCacheControlDirectives(systemPrompt, openAiMessages)
+			addCacheBreakpoints(systemPrompt, openAiMessages)
 		}
 
 		// Required by Anthropic; other providers default to max tokens allowed.

src/api/transform/caching.ts

Lines changed: 0 additions & 36 deletions
This file was deleted.

src/api/transform/caching/__tests__/anthropic.test.ts

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+// npx jest src/api/transform/caching/__tests__/anthropic.test.ts
+
+import OpenAI from "openai"
+
+import { addCacheBreakpoints } from "../anthropic"
+
+describe("addCacheBreakpoints (Anthropic)", () => {
+	const systemPrompt = "You are a helpful assistant."
+
+	it("should always add a cache breakpoint to the system prompt", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "Hello" },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should not add breakpoints to user messages if there are none", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: systemPrompt }]
+		const originalMessages = JSON.parse(JSON.stringify(messages))
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[0].content).toEqual([
+			{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages.length).toBe(originalMessages.length)
+	})
+
+	it("should add a breakpoint to the only user message if only one exists", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoints to both user messages if only two exist", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+			{ role: "user", content: "User message 2" },
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoints to the last two user messages when more than two exist", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" }, // Should not get breakpoint.
+			{ role: "user", content: "User message 2" }, // Should get breakpoint.
+			{ role: "user", content: "User message 3" }, // Should get breakpoint.
+		]
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([{ type: "text", text: "User message 1" }])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[3].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should handle assistant messages correctly when finding last two user messages", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" }, // Should not get breakpoint.
+			{ role: "assistant", content: "Assistant response 1" },
+			{ role: "user", content: "User message 2" }, // Should get breakpoint (second to last user).
+			{ role: "assistant", content: "Assistant response 2" },
+			{ role: "user", content: "User message 3" }, // Should get breakpoint (last user).
+			{ role: "assistant", content: "Assistant response 3" },
+		]
+		addCacheBreakpoints(systemPrompt, messages)
+
+		const userMessages = messages.filter((m) => m.role === "user")
+
+		expect(userMessages[0].content).toEqual([{ type: "text", text: "User message 1" }])
+
+		expect(userMessages[1].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(userMessages[2].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add breakpoint to the last text part if content is an array", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+			{
+				role: "user",
+				content: [
+					{ type: "text", text: "This is the last user message." },
+					{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+					{ type: "text", text: "This part should get the breakpoint." },
+				],
+			},
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "This is the last user message." },
+			{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+			{ type: "text", text: "This part should get the breakpoint.", cache_control: { type: "ephemeral" } },
+		])
+	})
+
+	it("should add a placeholder text part if the target message has no text parts", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" },
+			{
+				role: "user",
+				content: [{ type: "image_url", image_url: { url: "data:image/png;base64,..." } }],
+			},
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([
+			{ type: "text", text: "User message 1", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[2].content).toEqual([
+			{ type: "image_url", image_url: { url: "data:image/png;base64,..." } },
+			{ type: "text", text: "...", cache_control: { type: "ephemeral" } }, // Placeholder added.
+		])
+	})
+
+	it("should ensure content is array format even if no breakpoint added", () => {
+		const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+			{ role: "system", content: systemPrompt },
+			{ role: "user", content: "User message 1" }, // String content, no breakpoint.
+			{ role: "user", content: "User message 2" }, // Gets breakpoint.
+			{ role: "user", content: "User message 3" }, // Gets breakpoint.
+		]
+
+		addCacheBreakpoints(systemPrompt, messages)
+
+		expect(messages[1].content).toEqual([{ type: "text", text: "User message 1" }])
+
+		expect(messages[2].content).toEqual([
+			{ type: "text", text: "User message 2", cache_control: { type: "ephemeral" } },
+		])
+
+		expect(messages[3].content).toEqual([
+			{ type: "text", text: "User message 3", cache_control: { type: "ephemeral" } },
+		])
+	})
+})
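
The tests above pin down the behavior of the new OpenAI-format helper: the system prompt always gets a breakpoint, user message content is normalized to array form, and only the last two user messages receive an ephemeral cache_control marker on their last text part (with a "..." placeholder added when a message has no text part). For orientation, a sketch that would satisfy these assertions is shown below; it is inferred from the test expectations and from the inline logic removed from openrouter.ts, and is not necessarily the actual contents of src/api/transform/caching/anthropic.ts.

import OpenAI from "openai"

// Hypothetical sketch reconstructed from the tests above; the shipped
// implementation in src/api/transform/caching/anthropic.ts may differ.
export function addCacheBreakpoints(systemPrompt: string, messages: OpenAI.Chat.ChatCompletionMessageParam[]) {
	// The system prompt always gets a breakpoint (assumes it sits at index 0, as in the tests).
	messages[0] = {
		role: "system",
		// @ts-ignore-next-line (cache_control is not part of the OpenAI types)
		content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }],
	}

	const userMessages = messages.filter(
		(msg): msg is OpenAI.Chat.ChatCompletionUserMessageParam => msg.role === "user",
	)

	// Normalize user messages to array content so individual parts can carry metadata.
	for (const msg of userMessages) {
		if (typeof msg.content === "string") {
			msg.content = [{ type: "text", text: msg.content }]
		}
	}

	// Only the last two user messages get breakpoints.
	for (const msg of userMessages.slice(-2)) {
		if (!Array.isArray(msg.content)) {
			continue
		}

		// The breakpoint goes on the last text part; add a "..." placeholder if none exists.
		let lastTextPart = msg.content.filter((part) => part.type === "text").pop()

		if (!lastTextPart) {
			lastTextPart = { type: "text", text: "..." }
			msg.content.push(lastTextPart)
		}

		// @ts-ignore-next-line (cache_control is not part of the OpenAI types)
		lastTextPart["cache_control"] = { type: "ephemeral" }
	}
}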
