119 changes: 84 additions & 35 deletions src/api/providers/gemini.ts
@@ -21,12 +21,13 @@ import type { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const CACHE_TTL = 5

const CACHE_WRITE_FREQUENCY = 10
const CONTEXT_CACHE_TOKEN_MINIMUM = 4096

type CacheEntry = {
key: string
count: number
tokens?: number
}

type GeminiHandlerOptions = ApiHandlerOptions & {
@@ -96,7 +97,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
cacheKey &&
contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM

let cacheWrite = false
let isCacheWriteQueued = false

if (isCacheAvailable) {
const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)
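A minimal standalone sketch of the availability gate in the hunk above. The 4 multiplier is read here as a rough four-characters-per-token heuristic (an assumption, not stated in the diff), so caching is only attempted once the serialized contents plausibly exceed CONTEXT_CACHE_TOKEN_MINIMUM tokens.

const CONTEXT_CACHE_TOKEN_MINIMUM = 4096

// Hypothetical helper; `contentsLength` is assumed to be a character count of the
// serialized contents, with ~4 characters approximating one token.
function isContextCacheAvailable(cacheKey: string | undefined, contentsLength: number): boolean {
	return Boolean(cacheKey) && contentsLength > 4 * CONTEXT_CACHE_TOKEN_MINIMUM
}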
@@ -109,38 +110,10 @@
)
}

if (!this.isCacheBusy) {
this.isCacheBusy = true
const timestamp = Date.now()

this.client.caches
.create({
model,
config: {
contents,
systemInstruction,
ttl: `${CACHE_TTL * 60}s`,
httpOptions: { timeout: 120_000 },
},
})
.then((result) => {
const { name, usageMetadata } = result

if (name) {
this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
console.log(
`[GeminiHandler] cached ${contents.length} messages (${usageMetadata?.totalTokenCount ?? "-"} tokens) in ${Date.now() - timestamp}ms`,
)
}
})
.catch((error) => {
console.error(`[GeminiHandler] caches.create error`, error)
})
.finally(() => {
this.isCacheBusy = false
})

cacheWrite = true
// If `CACHE_WRITE_FREQUENCY` messages have been appended since the
// last cache write, write a new cache entry.
if (!cacheEntry || (uncachedContent && uncachedContent.length >= CACHE_WRITE_FREQUENCY)) {
isCacheWriteQueued = true
}
}
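A standalone sketch of the write-frequency gate introduced above (a hypothetical helper, not part of the PR): a cache write is queued when no entry exists yet, or once at least CACHE_WRITE_FREQUENCY messages have accumulated past the previously cached count. `uncachedContent` is assumed to be the tail of `contents` beyond `cacheEntry.count`.

const CACHE_WRITE_FREQUENCY = 10

type CacheEntry = { key: string; count: number; tokens?: number }

function shouldQueueCacheWrite(cacheEntry: CacheEntry | undefined, contents: unknown[]): boolean {
	// No entry yet for this conversation: write the first one.
	if (!cacheEntry) {
		return true
	}

	// Messages appended since the last cache write (assumed shape of `uncachedContent`).
	const uncachedContent = contents.slice(cacheEntry.count)
	return uncachedContent.length >= CACHE_WRITE_FREQUENCY
}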

@@ -163,6 +136,10 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl

const result = await this.client.models.generateContentStream(params)

if (cacheKey && isCacheWriteQueued) {
this.writeCache({ cacheKey, model, systemInstruction, contents })
}

let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined

for await (const chunk of result) {
@@ -178,7 +155,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
if (lastUsageMetadata) {
const inputTokens = lastUsageMetadata.promptTokenCount ?? 0
const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0
const cacheWriteTokens = cacheWrite ? inputTokens : undefined
const cacheWriteTokens = isCacheWriteQueued ? inputTokens : undefined
const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
const reasoningTokens = lastUsageMetadata.thoughtsTokenCount

@@ -338,4 +315,76 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl

return totalCost
}

private writeCache({
cacheKey,
model,
systemInstruction,
contents,
}: {
cacheKey: string
model: string
systemInstruction: string
contents: Content[]
}) {
if (this.isCacheBusy) {
return
}

this.isCacheBusy = true
const timestamp = Date.now()

const previousCacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)

this.client.caches
.create({
model,
config: {
contents,
systemInstruction,
ttl: `${CACHE_TTL * 60}s`,
httpOptions: { timeout: 120_000 },
},
})
.then((result) => {
const { name, usageMetadata } = result

if (name) {
const newCacheEntry: CacheEntry = {
key: name,
count: contents.length,
tokens: usageMetadata?.totalTokenCount,
}

this.contentCaches.set<CacheEntry>(cacheKey, newCacheEntry)

console.log(
`[GeminiHandler] created cache entry ${newCacheEntry.key} -> ${newCacheEntry.count} messages, ${newCacheEntry.tokens} tokens (${Date.now() - timestamp}ms)`,
)

if (previousCacheEntry) {
const timestamp = Date.now()

this.client.caches
.delete({ name: previousCacheEntry.key })
.then(() => {
console.log(
`[GeminiHandler] deleted cache entry ${previousCacheEntry.key} -> ${previousCacheEntry.count} messages, ${previousCacheEntry.tokens} tokens (${Date.now() - timestamp}ms)`,
)
})
.catch((error) => {
console.error(
`[GeminiHandler] failed to delete stale cache entry ${previousCacheEntry.key} -> ${error instanceof Error ? error.message : String(error)}`,
)
})
}
}
})
.catch((error) => {
console.error(`[GeminiHandler] caches.create error`, error)
})
.finally(() => {
this.isCacheBusy = false
})
}
}
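Finally, a compressed sketch of the rotation pattern `writeCache` implements: create the replacement cached-content entry, point the in-memory map at it, and only then delete the stale one. It mirrors the calls visible in the diff (`client.caches.create`, `client.caches.delete`, and the node-cache style `get`/`set`); the freestanding function and its loose parameter types are hypothetical.

type CacheEntry = { key: string; count: number; tokens?: number }

async function rotateCache(
	client: { caches: { create: (req: any) => Promise<any>; delete: (req: { name: string }) => Promise<unknown> } },
	contentCaches: { get: <T>(k: string) => T | undefined; set: <T>(k: string, v: T) => boolean },
	cacheKey: string,
	model: string,
	systemInstruction: string,
	contents: unknown[],
): Promise<void> {
	const previous = contentCaches.get<CacheEntry>(cacheKey)

	// 1. Create the replacement entry (TTL matching CACHE_TTL above: 5 minutes).
	const created = await client.caches.create({
		model,
		config: { contents, systemInstruction, ttl: `${5 * 60}s`, httpOptions: { timeout: 120_000 } },
	})

	if (!created?.name) {
		return
	}

	// 2. Record the new entry before removing the old one, so lookups never see a gap.
	contentCaches.set<CacheEntry>(cacheKey, {
		key: created.name,
		count: contents.length,
		tokens: created.usageMetadata?.totalTokenCount,
	})

	// 3. Best-effort cleanup of the stale entry; a failure is logged, not rethrown.
	if (previous) {
		await client.caches.delete({ name: previous.key }).catch((error: unknown) => {
			console.error(`[GeminiHandler] failed to delete stale cache entry ${previous.key}`, error)
		})
	}
}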