5 changes: 5 additions & 0 deletions .changeset/bright-singers-drop.md
@@ -0,0 +1,5 @@
---
"roo-cline": patch
---

Enable prompt caching for Gemini (with some improvements)
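For context, here is a minimal sketch of the explicit caching flow this change adopts, assuming the `@google/genai` caches API as used in the diff below; the model id, prompts, and TTL are placeholder values, not taken from the handler.

```ts
import { GoogleGenAI, type Content } from "@google/genai"

// Placeholder conversation; the handler builds this from the real task messages.
const contents: Content[] = [{ role: "user", parts: [{ text: "Summarize the attached design doc." }] }]
const systemInstruction = "You are a coding assistant."
const model = "gemini-2.0-flash-001" // placeholder model id

async function sketch() {
	const client = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY ?? "not-provided" })

	// 1. Write the prompt prefix into an explicit cache with a short TTL.
	const cache = await client.caches.create({
		model,
		config: { contents, systemInstruction, ttl: "300s" },
	})

	// 2. Later requests reference the cache by name and only send the new turns.
	if (cache.name) {
		const response = await client.models.generateContent({
			model,
			contents: [{ role: "user", parts: [{ text: "Now list the open questions." }] }],
			config: { cachedContent: cache.name },
		})
		console.log(response.text)
	}
}

sketch().catch(console.error)
```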
113 changes: 85 additions & 28 deletions src/api/providers/gemini.ts
@@ -4,27 +4,42 @@ import {
type GenerateContentResponseUsageMetadata,
type GenerateContentParameters,
type Content,
CreateCachedContentConfig,
} from "@google/genai"
import NodeCache from "node-cache"

import { SingleCompletionHandler } from "../"
import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
import { geminiDefaultModelId, geminiModels } from "../../shared/api"
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
import {
convertAnthropicContentToGemini,
convertAnthropicMessageToGemini,
getMessagesLength,
} from "../transform/gemini-format"
import type { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const CACHE_TTL = 5

const CONTEXT_CACHE_TOKEN_MINIMUM = 4096

type CacheEntry = {
key: string
count: number
}

export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions

private client: GoogleGenAI
private contentCaches: Map<string, { key: string; count: number }>
private contentCaches: NodeCache
private isCacheBusy = false

constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
this.contentCaches = new Map()
this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
}

async *createMessage(
@@ -35,36 +50,76 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()

const contents = messages.map(convertAnthropicMessageToGemini)
const contentsLength = systemInstruction.length + getMessagesLength(contents)

let uncachedContent: Content[] | undefined = undefined
let cachedContent: string | undefined = undefined
let cacheWriteTokens: number | undefined = undefined

// The minimum input token count for context caching is 4,096.
// For a basic approximation we assume 4 characters per token.
// We could eventually use tiktoken to get a more accurate token count.
// https://ai.google.dev/gemini-api/docs/caching?lang=node
// if (info.supportsPromptCache && cacheKey) {
// const cacheEntry = this.contentCaches.get(cacheKey)

// if (cacheEntry) {
// uncachedContent = contents.slice(cacheEntry.count, contents.length)
// cachedContent = cacheEntry.key
// }
// https://ai.google.dev/gemini-api/docs/tokens?lang=node
const isCacheAvailable =
info.supportsPromptCache &&
this.options.promptCachingEnabled &&
cacheKey &&
contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM

if (isCacheAvailable) {
const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)

if (cacheEntry) {
uncachedContent = contents.slice(cacheEntry.count, contents.length)
cachedContent = cacheEntry.key
console.log(
`[GeminiHandler] using ${cacheEntry.count} cached messages (${cacheEntry.key}) and ${uncachedContent.length} uncached messages`,
)
}

// const newCacheEntry = await this.client.caches.create({
// model,
// config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
// })
if (!this.isCacheBusy) {
this.isCacheBusy = true
const timestamp = Date.now()

this.client.caches
.create({
model,
config: {
contents,
systemInstruction,
ttl: `${CACHE_TTL * 60}s`,
httpOptions: { timeout: 120_000 },
},
})
.then((result) => {
const { name, usageMetadata } = result

if (name) {
this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
cacheWriteTokens = usageMetadata?.totalTokenCount ?? 0
console.log(
`[GeminiHandler] cached ${contents.length} messages (${cacheWriteTokens} tokens) in ${Date.now() - timestamp}ms`,
)
}
})
.catch((error) => {
console.error(`[GeminiHandler] caches.create error`, error)
})
.finally(() => {
this.isCacheBusy = false
})
}
}

// if (newCacheEntry.name) {
// this.contentCaches.set(cacheKey, { key: newCacheEntry.name, count: contents.length })
// cacheWriteTokens = newCacheEntry.usageMetadata?.totalTokenCount ?? 0
// }
// }
const isCacheUsed = !!cachedContent

const params: GenerateContentParameters = {
model,
contents: uncachedContent ?? contents,
config: {
cachedContent,
systemInstruction: cachedContent ? undefined : systemInstruction,
systemInstruction: isCacheUsed ? undefined : systemInstruction,
httpOptions: this.options.googleGeminiBaseUrl
? { baseUrl: this.options.googleGeminiBaseUrl }
: undefined,
@@ -94,13 +149,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
const reasoningTokens = lastUsageMetadata.thoughtsTokenCount

// const totalCost = this.calculateCost({
// info,
// inputTokens,
// outputTokens,
// cacheWriteTokens,
// cacheReadTokens,
// })
const totalCost = isCacheUsed
? this.calculateCost({
info,
inputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
})
: undefined

yield {
type: "usage",
@@ -109,7 +166,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
cacheWriteTokens,
cacheReadTokens,
reasoningTokens,
// totalCost,
totalCost,
}
}
}
6 changes: 6 additions & 0 deletions src/api/transform/gemini-format.ts
@@ -76,3 +76,9 @@ export function convertAnthropicMessageToGemini(message: Anthropic.Messages.Mess
parts: convertAnthropicContentToGemini(message.content),
}
}

const getContentLength = ({ parts }: Content): number =>
parts?.reduce((length, { text }) => length + (text?.length ?? 0), 0) ?? 0

export const getMessagesLength = (contents: Content[]): number =>
contents.reduce((length, content) => length + getContentLength(content), 0)
6 changes: 3 additions & 3 deletions src/shared/api.ts
@@ -679,7 +679,7 @@ export const geminiModels = {
maxTokens: 65_535,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.
outputPrice: 15,
@@ -704,7 +704,7 @@ export const geminiModels = {
maxTokens: 8192,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 0.1,
outputPrice: 0.4,
@@ -755,7 +755,7 @@ export const geminiModels = {
maxTokens: 8192,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 0.15, // This is the pricing for prompts above 128k tokens.
outputPrice: 0.6,