Skip to content

Commit a3f1a3f

Browse files
authored
Gemini caching improvements (RooCodeInc#2925)
1 parent 7f99c06 commit a3f1a3f

File tree

4 files changed

+99
-31
lines changed

4 files changed

+99
-31
lines changed

.changeset/bright-singers-drop.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"roo-cline": patch
3+
---
4+
5+
Enable prompt caching for Gemini (with some improvements)

src/api/providers/gemini.ts

Lines changed: 85 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,42 @@ import {
44
type GenerateContentResponseUsageMetadata,
55
type GenerateContentParameters,
66
type Content,
7+
CreateCachedContentConfig,
78
} from "@google/genai"
9+
import NodeCache from "node-cache"
810

911
import { SingleCompletionHandler } from "../"
1012
import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
1113
import { geminiDefaultModelId, geminiModels } from "../../shared/api"
12-
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
14+
import {
15+
convertAnthropicContentToGemini,
16+
convertAnthropicMessageToGemini,
17+
getMessagesLength,
18+
} from "../transform/gemini-format"
1319
import type { ApiStream } from "../transform/stream"
1420
import { BaseProvider } from "./base-provider"
1521

1622
const CACHE_TTL = 5
1723

24+
const CONTEXT_CACHE_TOKEN_MINIMUM = 4096
25+
26+
type CacheEntry = {
27+
key: string
28+
count: number
29+
}
30+
1831
export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
1932
protected options: ApiHandlerOptions
33+
2034
private client: GoogleGenAI
21-
private contentCaches: Map<string, { key: string; count: number }>
35+
private contentCaches: NodeCache
36+
private isCacheBusy = false
2237

2338
constructor(options: ApiHandlerOptions) {
2439
super()
2540
this.options = options
2641
this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
27-
this.contentCaches = new Map()
42+
this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
2843
}
2944

3045
async *createMessage(
@@ -35,36 +50,76 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
3550
const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()
3651

3752
const contents = messages.map(convertAnthropicMessageToGemini)
53+
const contentsLength = systemInstruction.length + getMessagesLength(contents)
54+
3855
let uncachedContent: Content[] | undefined = undefined
3956
let cachedContent: string | undefined = undefined
4057
let cacheWriteTokens: number | undefined = undefined
4158

59+
// The minimum input token count for context caching is 4,096.
60+
// For a basic approximation we assume 4 characters per token.
61+
// We can use tiktoken eventually to get a more accurate token count.
4262
// https://ai.google.dev/gemini-api/docs/caching?lang=node
43-
// if (info.supportsPromptCache && cacheKey) {
44-
// const cacheEntry = this.contentCaches.get(cacheKey)
45-
46-
// if (cacheEntry) {
47-
// uncachedContent = contents.slice(cacheEntry.count, contents.length)
48-
// cachedContent = cacheEntry.key
49-
// }
63+
// https://ai.google.dev/gemini-api/docs/tokens?lang=node
64+
const isCacheAvailable =
65+
info.supportsPromptCache &&
66+
this.options.promptCachingEnabled &&
67+
cacheKey &&
68+
contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM
69+
70+
if (isCacheAvailable) {
71+
const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)
72+
73+
if (cacheEntry) {
74+
uncachedContent = contents.slice(cacheEntry.count, contents.length)
75+
cachedContent = cacheEntry.key
76+
console.log(
77+
`[GeminiHandler] using ${cacheEntry.count} cached messages (${cacheEntry.key}) and ${uncachedContent.length} uncached messages`,
78+
)
79+
}
5080

51-
// const newCacheEntry = await this.client.caches.create({
52-
// model,
53-
// config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
54-
// })
81+
if (!this.isCacheBusy) {
82+
this.isCacheBusy = true
83+
const timestamp = Date.now()
84+
85+
this.client.caches
86+
.create({
87+
model,
88+
config: {
89+
contents,
90+
systemInstruction,
91+
ttl: `${CACHE_TTL * 60}s`,
92+
httpOptions: { timeout: 120_000 },
93+
},
94+
})
95+
.then((result) => {
96+
const { name, usageMetadata } = result
97+
98+
if (name) {
99+
this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
100+
cacheWriteTokens = usageMetadata?.totalTokenCount ?? 0
101+
console.log(
102+
`[GeminiHandler] cached ${contents.length} messages (${cacheWriteTokens} tokens) in ${Date.now() - timestamp}ms`,
103+
)
104+
}
105+
})
106+
.catch((error) => {
107+
console.error(`[GeminiHandler] caches.create error`, error)
108+
})
109+
.finally(() => {
110+
this.isCacheBusy = false
111+
})
112+
}
113+
}
55114

56-
// if (newCacheEntry.name) {
57-
// this.contentCaches.set(cacheKey, { key: newCacheEntry.name, count: contents.length })
58-
// cacheWriteTokens = newCacheEntry.usageMetadata?.totalTokenCount ?? 0
59-
// }
60-
// }
115+
const isCacheUsed = !!cachedContent
61116

62117
const params: GenerateContentParameters = {
63118
model,
64119
contents: uncachedContent ?? contents,
65120
config: {
66121
cachedContent,
67-
systemInstruction: cachedContent ? undefined : systemInstruction,
122+
systemInstruction: isCacheUsed ? undefined : systemInstruction,
68123
httpOptions: this.options.googleGeminiBaseUrl
69124
? { baseUrl: this.options.googleGeminiBaseUrl }
70125
: undefined,
@@ -94,13 +149,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
94149
const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
95150
const reasoningTokens = lastUsageMetadata.thoughtsTokenCount
96151

97-
// const totalCost = this.calculateCost({
98-
// info,
99-
// inputTokens,
100-
// outputTokens,
101-
// cacheWriteTokens,
102-
// cacheReadTokens,
103-
// })
152+
const totalCost = isCacheUsed
153+
? this.calculateCost({
154+
info,
155+
inputTokens,
156+
outputTokens,
157+
cacheWriteTokens,
158+
cacheReadTokens,
159+
})
160+
: undefined
104161

105162
yield {
106163
type: "usage",
@@ -109,7 +166,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
109166
cacheWriteTokens,
110167
cacheReadTokens,
111168
reasoningTokens,
112-
// totalCost,
169+
totalCost,
113170
}
114171
}
115172
}

src/api/transform/gemini-format.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,9 @@ export function convertAnthropicMessageToGemini(message: Anthropic.Messages.Mess
7676
parts: convertAnthropicContentToGemini(message.content),
7777
}
7878
}
79+
80+
const getContentLength = ({ parts }: Content): number =>
81+
parts?.reduce((length, { text }) => length + (text?.length ?? 0), 0) ?? 0
82+
83+
export const getMessagesLength = (contents: Content[]): number =>
84+
contents.reduce((length, content) => length + getContentLength(content), 0)

src/shared/api.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ export const geminiModels = {
679679
maxTokens: 65_535,
680680
contextWindow: 1_048_576,
681681
supportsImages: true,
682-
supportsPromptCache: false,
682+
supportsPromptCache: true,
683683
isPromptCacheOptional: true,
684684
inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.
685685
outputPrice: 15,
@@ -704,7 +704,7 @@ export const geminiModels = {
704704
maxTokens: 8192,
705705
contextWindow: 1_048_576,
706706
supportsImages: true,
707-
supportsPromptCache: false,
707+
supportsPromptCache: true,
708708
isPromptCacheOptional: true,
709709
inputPrice: 0.1,
710710
outputPrice: 0.4,
@@ -755,7 +755,7 @@ export const geminiModels = {
755755
maxTokens: 8192,
756756
contextWindow: 1_048_576,
757757
supportsImages: true,
758-
supportsPromptCache: false,
758+
supportsPromptCache: true,
759759
isPromptCacheOptional: true,
760760
inputPrice: 0.15, // This is the pricing for prompts above 128k tokens.
761761
outputPrice: 0.6,

0 commit comments

Comments
 (0)