Skip to content

Commit a525d6d

Browse files
authored
Adding Caching to gemini provider (RooCodeInc#3072)
1 parent 59dd323 commit a525d6d

File tree

1 file changed

+70
-10
lines changed

1 file changed

+70
-10
lines changed

src/api/providers/gemini.ts

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
import type { Anthropic } from "@anthropic-ai/sdk"
22
// Restore GenerateContentConfig import and add GenerateContentResponseUsageMetadata
3-
import {
4-
GoogleGenAI,
5-
type GenerationConfig,
6-
type Content,
7-
type GenerateContentConfig,
8-
type GenerateContentResponseUsageMetadata,
9-
} from "@google/genai"
3+
import { GoogleGenAI, type Content, type GenerateContentConfig, type GenerateContentResponseUsageMetadata } from "@google/genai"
104
import { withRetry } from "../retry"
115
import { ApiHandler } from "../"
126
import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "@shared/api"
137
import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
148
import { ApiStream } from "../transform/stream"
159

10+
// Default time-to-live for the Gemini context cache: 1 hour, expressed in seconds
11+
const DEFAULT_CACHE_TTL_SECONDS = 3600
12+
1613
export class GeminiHandler implements ApiHandler {
1714
private options: ApiHandlerOptions
1815
private client: GoogleGenAI // Updated client type
1916

17+
// Internal state for caching
18+
private cacheName: string | null = null
19+
private cacheExpireTime: number | null = null
20+
private isFirstApiCall = true
21+
2022
constructor(options: ApiHandlerOptions) {
2123
if (!options.geminiApiKey) {
2224
throw new Error("API key is required for Google Gemini")
@@ -30,18 +32,45 @@ export class GeminiHandler implements ApiHandler {
3032
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
3133
const { id: modelId, info: modelInfo } = this.getModel()
3234

35+
// --- Cache Handling Logic ---
36+
const isCacheValid = this.cacheName && this.cacheExpireTime && Date.now() < this.cacheExpireTime
37+
let useCache = !this.isFirstApiCall && isCacheValid
38+
39+
if (this.isFirstApiCall && !isCacheValid && systemPrompt) {
40+
// It's the first call, no valid cache exists, and we have a system prompt. Attempt cache creation.
41+
this.isFirstApiCall = false
42+
43+
// Minimum token check heuristic (simple length check for now, could be improved)
44+
// Gemini requires minimum 4096 tokens. A simple length check isn't accurate but avoids complex token counting here.
45+
// Let's assume a generous average of 4 chars/token. 4096 tokens * 4 chars/token = 16384 chars.
46+
const MIN_SYSTEM_PROMPT_LENGTH_FOR_CACHE = 16384
47+
if (systemPrompt.length >= MIN_SYSTEM_PROMPT_LENGTH_FOR_CACHE) {
48+
// Start cache creation asynchronously, don't block the main request
49+
this.createCacheInBackground(modelId, systemPrompt)
50+
}
51+
// Proceed with the first request *without* using the cache, as it's being created.
52+
useCache = false
53+
} else if (!isCacheValid && this.cacheName) {
54+
// Cache exists but has expired
55+
this.cacheName = null
56+
this.cacheExpireTime = null
57+
useCache = false
58+
}
59+
// --- End Cache Handling Logic ---
60+
3361
// Re-implement thinking budget logic based on new SDK structure
3462
const thinkingBudget = this.options.thinkingBudgetTokens ?? 0
3563
const maxBudget = modelInfo.thinkingConfig?.maxBudget ?? 0
3664

3765
// port add baseUrl configuration for gemini api requests (#2843)
3866
const httpOptions = this.options.geminiBaseUrl ? { baseUrl: this.options.geminiBaseUrl } : undefined
3967

40-
// Base generation config - Restore type and systemInstruction
68+
// Base generation config - Conditionally include systemInstruction based on cache usage
4169
const generationConfig: GenerateContentConfig = {
4270
httpOptions,
4371
temperature: 0, // Default temperature
44-
systemInstruction: systemPrompt, // System prompt belongs here
72+
// Only include systemInstruction if NOT using the cache
73+
...(useCache ? {} : { systemInstruction: systemPrompt }),
4574
}
4675

4776
// Convert messages to the format expected by @google/genai
@@ -64,7 +93,11 @@ export class GeminiHandler implements ApiHandler {
6493
const result = await this.client.models.generateContentStream({
6594
model: modelId, // Pass model ID directly
6695
contents,
67-
config: requestConfig, // Pass the combined config (which includes systemInstruction)
96+
// Add cachedContent if using the cache
97+
config: {
98+
...requestConfig,
99+
...(useCache ? { cachedContent: this.cacheName! } : {}),
100+
},
68101
})
69102

70103
// Declare variable to hold the last usage metadata found
@@ -88,7 +121,34 @@ export class GeminiHandler implements ApiHandler {
88121
type: "usage",
89122
inputTokens: lastUsageMetadata.promptTokenCount ?? 0,
90123
outputTokens: lastUsageMetadata.candidatesTokenCount ?? 0,
124+
cacheWriteTokens: lastUsageMetadata.cachedContentTokenCount ?? 0,
125+
cacheReadTokens: useCache ? (lastUsageMetadata.promptTokenCount ?? 0) : 0, // If cache used, prompt tokens are read from cache
126+
}
127+
}
128+
}
129+
130+
/**
 * Creates a Gemini context cache containing the system instruction and records
 * its name and local expiry timestamp on this handler.
 *
 * Failures are caught and logged, so the returned promise does not reject; on
 * failure the cached name and expiry are cleared.
 *
 * @param modelId - Model identifier the cache is created for.
 * @param systemInstruction - System prompt text to store in the cache.
 * @returns A promise that resolves once the cache state has been updated or reset.
 */
private async createCacheInBackground(modelId: string, systemInstruction: string): Promise<void> {
	// Take the timestamp BEFORE issuing the request. Deriving the expiry from the
	// moment the response arrives would overstate the cache lifetime by the
	// round-trip latency, so an already-expired cache could still be referenced.
	const requestStartedAt = Date.now()

	try {
		const cache = await this.client.caches.create({
			model: modelId,
			config: {
				systemInstruction: systemInstruction,
				ttl: `${DEFAULT_CACHE_TTL_SECONDS}s`,
			},
		})

		if (cache?.name) {
			this.cacheName = cache.name
			// Expiry is estimated locally from the requested TTL, anchored at the
			// request start so the estimate errs on the side of expiring early.
			this.cacheExpireTime = requestStartedAt + DEFAULT_CACHE_TTL_SECONDS * 1000
		} else {
			console.warn("Gemini cache creation call succeeded but returned no cache name.")
		}
	} catch (error) {
		console.error("Failed to create Gemini cache in background:", error)
		// Reset state if creation failed definitively
		this.cacheName = null
		this.cacheExpireTime = null
	}
}
94154

0 commit comments

Comments
 (0)