5 changes: 5 additions & 0 deletions .changeset/bright-singers-drop.md
@@ -0,0 +1,5 @@
---
"roo-cline": patch
---

Enable prompt caching for Gemini (with some improvements)
105 changes: 78 additions & 27 deletions src/api/providers/gemini.ts
@@ -4,27 +4,40 @@ import {
type GenerateContentResponseUsageMetadata,
type GenerateContentParameters,
type Content,
CreateCachedContentConfig,
} from "@google/genai"
import NodeCache from "node-cache"

import { SingleCompletionHandler } from "../"
import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
import { geminiDefaultModelId, geminiModels } from "../../shared/api"
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
import {
convertAnthropicContentToGemini,
convertAnthropicMessageToGemini,
getMessagesLength,
} from "../transform/gemini-format"
import type { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const CACHE_TTL = 5

const CONTEXT_CACHE_TOKEN_MINIMUM = 4096

type CacheEntry = {
key: string
count: number
}

export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: GoogleGenAI
private contentCaches: Map<string, { key: string; count: number }>
private contentCaches: NodeCache

constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
this.contentCaches = new Map()
this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
}

async *createMessage(
@@ -35,36 +48,72 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()

const contents = messages.map(convertAnthropicMessageToGemini)
const contentsLength = systemInstruction.length + getMessagesLength(contents)

let uncachedContent: Content[] | undefined = undefined
let cachedContent: string | undefined = undefined
let cacheWriteTokens: number | undefined = undefined

// The minimum input token count for context caching is 4,096.
// For a basic approximation we assume 4 characters per token.
// We can use tiktoken eventually to get a more accurate token count.
// https://ai.google.dev/gemini-api/docs/caching?lang=node
// if (info.supportsPromptCache && cacheKey) {
// const cacheEntry = this.contentCaches.get(cacheKey)
// https://ai.google.dev/gemini-api/docs/tokens?lang=node
const isCacheAvailable =
info.supportsPromptCache &&
this.options.promptCachingEnabled &&
cacheKey &&
contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM

console.log(`[GeminiHandler] isCacheAvailable=${isCacheAvailable}, contentsLength=${contentsLength}`)

if (isCacheAvailable) {
const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)

if (cacheEntry) {
uncachedContent = contents.slice(cacheEntry.count, contents.length)
cachedContent = cacheEntry.key
console.log(
`[GeminiHandler] using ${cacheEntry.count} cached messages (${cacheEntry.key}) and ${uncachedContent.length} uncached messages`,
)
}

const timestamp = Date.now()

// if (cacheEntry) {
// uncachedContent = contents.slice(cacheEntry.count, contents.length)
// cachedContent = cacheEntry.key
// }
const config: CreateCachedContentConfig = {
contents,
systemInstruction,
ttl: `${CACHE_TTL * 60}s`,
httpOptions: { timeout: 10_000 },
}

// const newCacheEntry = await this.client.caches.create({
// model,
// config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
// })
this.client.caches
.create({ model, config })
.then((result) => {
console.log(`[GeminiHandler] caches.create result -> ${JSON.stringify(result)}`)
const { name, usageMetadata } = result

if (name) {
this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
cacheWriteTokens = usageMetadata?.totalTokenCount ?? 0
console.log(
`[GeminiHandler] cached ${contents.length} messages (${cacheWriteTokens} tokens) in ${Date.now() - timestamp}ms`,
)
}
})
.catch((error) => {
console.error(`[GeminiHandler] caches.create error`, error)
})
}

// if (newCacheEntry.name) {
// this.contentCaches.set(cacheKey, { key: newCacheEntry.name, count: contents.length })
// cacheWriteTokens = newCacheEntry.usageMetadata?.totalTokenCount ?? 0
// }
// }
const isCacheUsed = !!cachedContent

const params: GenerateContentParameters = {
model,
contents: uncachedContent ?? contents,
config: {
cachedContent,
systemInstruction: cachedContent ? undefined : systemInstruction,
systemInstruction: isCacheUsed ? undefined : systemInstruction,
httpOptions: this.options.googleGeminiBaseUrl
? { baseUrl: this.options.googleGeminiBaseUrl }
: undefined,
@@ -94,13 +143,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
const reasoningTokens = lastUsageMetadata.thoughtsTokenCount

// const totalCost = this.calculateCost({
// info,
// inputTokens,
// outputTokens,
// cacheWriteTokens,
// cacheReadTokens,
// })
const totalCost = isCacheUsed
? this.calculateCost({
info,
inputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
})
: undefined

yield {
type: "usage",
@@ -109,7 +160,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
cacheWriteTokens,
cacheReadTokens,
reasoningTokens,
// totalCost,
totalCost,
}
}
}
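For reference, a rough standalone sketch of the explicit-caching flow this handler now follows, assuming the same `@google/genai` calls used in the diff (`caches.create`, `models.generateContent`); the model id and prompt contents are placeholders, not values from the PR.

```typescript
import { GoogleGenAI, type Content } from "@google/genai"

// Sketch only: mirrors the handler's flow of writing the shared prefix to an explicit
// cache, then referencing it by name on later requests. Model id is a placeholder.
async function cachedGenerate(apiKey: string, systemInstruction: string, prefix: Content[]) {
	const client = new GoogleGenAI({ apiKey })

	// Write the system prompt plus prior messages to a server-side cache with a
	// 5-minute TTL, matching CACHE_TTL in the handler above.
	const cache = await client.caches.create({
		model: "gemini-2.0-flash-001",
		config: { contents: prefix, systemInstruction, ttl: `${5 * 60}s` },
	})

	// Later requests pass only the uncached messages and reference the cache by name;
	// the system instruction is omitted because it is already part of the cached content.
	const response = await client.models.generateContent({
		model: "gemini-2.0-flash-001",
		contents: [{ role: "user", parts: [{ text: "Follow-up question" }] }],
		config: { cachedContent: cache.name },
	})

	// usageMetadata.cachedContentTokenCount reports how many input tokens were read from the cache.
	return {
		text: response.text,
		cacheReadTokens: response.usageMetadata?.cachedContentTokenCount,
	}
}
```

Note that in the handler itself the cache write is fire-and-forget (`.then()`/`.catch()` without `await`), so the current request is not blocked on `caches.create`; the cache entry only benefits subsequent requests.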
6 changes: 6 additions & 0 deletions src/api/transform/gemini-format.ts
@@ -76,3 +76,9 @@ export function convertAnthropicMessageToGemini(message: Anthropic.Messages.Mess
parts: convertAnthropicContentToGemini(message.content),
}
}

const getContentLength = ({ parts }: Content): number =>
parts?.reduce((length, { text }) => length + (text?.length ?? 0), 0) ?? 0

export const getMessagesLength = (contents: Content[]): number =>
contents.reduce((length, content) => length + getContentLength(content), 0)
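A minimal sketch of the threshold heuristic the handler's comment describes (a 4,096-token minimum for context caching, approximated at roughly 4 characters per token), using the `getMessagesLength` helper added above; the import path is an assumption.

```typescript
import type { Content } from "@google/genai"
import { getMessagesLength } from "./gemini-format" // assumed path

const CONTEXT_CACHE_TOKEN_MINIMUM = 4096

// Caching is only attempted when the system prompt plus message text is long enough
// that the ~4-characters-per-token estimate clears the 4,096-token minimum.
function shouldAttemptCache(systemInstruction: string, contents: Content[]): boolean {
	const contentsLength = systemInstruction.length + getMessagesLength(contents)
	return contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM
}
```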
6 changes: 3 additions & 3 deletions src/shared/api.ts
@@ -679,7 +679,7 @@ export const geminiModels = {
maxTokens: 65_535,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.
outputPrice: 15,
@@ -704,7 +704,7 @@ export const geminiModels = {
maxTokens: 8192,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 0.1,
outputPrice: 0.4,
@@ -755,7 +755,7 @@ export const geminiModels = {
maxTokens: 8192,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 0.15, // This is the pricing for prompts above 128k tokens.
outputPrice: 0.6,