96 changes: 69 additions & 27 deletions src/api/providers/gemini.ts
@@ -4,27 +4,38 @@ import {
type GenerateContentResponseUsageMetadata,
type GenerateContentParameters,
type Content,
CreateCachedContentConfig,
} from "@google/genai"
import NodeCache from "node-cache"

import { SingleCompletionHandler } from "../"
import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
import { geminiDefaultModelId, geminiModels } from "../../shared/api"
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
import {
convertAnthropicContentToGemini,
convertAnthropicMessageToGemini,
getMessagesLength,
} from "../transform/gemini-format"
import type { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

const CACHE_TTL = 5

type CacheEntry = {
key: string
count: number
}

export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: GoogleGenAI
private contentCaches: Map<string, { key: string; count: number }>
private contentCaches: NodeCache

constructor(options: ApiHandlerOptions) {
super()
this.options = options
this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
this.contentCaches = new Map()
this.contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
}

async *createMessage(
@@ -35,36 +46,65 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()

const contents = messages.map(convertAnthropicMessageToGemini)
const contentsLength = systemInstruction.length + getMessagesLength(contents)

let uncachedContent: Content[] | undefined = undefined
let cachedContent: string | undefined = undefined
let cacheWriteTokens: number | undefined = undefined

const isCacheAvailable =
info.supportsPromptCache && this.options.promptCachingEnabled && cacheKey && contentsLength > 16_384

console.log(`[GeminiHandler] isCacheAvailable=${isCacheAvailable}, contentsLength=${contentsLength}`)

// https://ai.google.dev/gemini-api/docs/caching?lang=node
// if (info.supportsPromptCache && cacheKey) {
// const cacheEntry = this.contentCaches.get(cacheKey)
if (isCacheAvailable) {
const cacheEntry = this.contentCaches.get<CacheEntry>(cacheKey)

if (cacheEntry) {
uncachedContent = contents.slice(cacheEntry.count, contents.length)
cachedContent = cacheEntry.key
console.log(
`[GeminiHandler] using ${cacheEntry.count} cached messages (${cacheEntry.key}) and ${uncachedContent.length} uncached messages`,
)
}

// if (cacheEntry) {
// uncachedContent = contents.slice(cacheEntry.count, contents.length)
// cachedContent = cacheEntry.key
// }
const timestamp = Date.now()

// const newCacheEntry = await this.client.caches.create({
// model,
// config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
// })
const config: CreateCachedContentConfig = {
contents,
systemInstruction,
ttl: `${CACHE_TTL * 60}s`,
httpOptions: { timeout: 10_000 },
}

this.client.caches
.create({ model, config })
.then((result) => {
console.log(`[GeminiHandler] caches.create result -> ${JSON.stringify(result)}`)
const { name, usageMetadata } = result

if (name) {
this.contentCaches.set<CacheEntry>(cacheKey, { key: name, count: contents.length })
cacheWriteTokens = usageMetadata?.totalTokenCount ?? 0
console.log(
`[GeminiHandler] cached ${contents.length} messages (${cacheWriteTokens} tokens) in ${Date.now() - timestamp}ms`,
)
}
})
.catch((error) => {
console.error(`[GeminiHandler] caches.create error`, error)
})
}

// if (newCacheEntry.name) {
// this.contentCaches.set(cacheKey, { key: newCacheEntry.name, count: contents.length })
// cacheWriteTokens = newCacheEntry.usageMetadata?.totalTokenCount ?? 0
// }
// }
const isCacheUsed = !!cachedContent

const params: GenerateContentParameters = {
model,
contents: uncachedContent ?? contents,
config: {
cachedContent,
systemInstruction: cachedContent ? undefined : systemInstruction,
systemInstruction: isCacheUsed ? undefined : systemInstruction,
httpOptions: this.options.googleGeminiBaseUrl
? { baseUrl: this.options.googleGeminiBaseUrl }
: undefined,
@@ -94,13 +134,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
const reasoningTokens = lastUsageMetadata.thoughtsTokenCount

// const totalCost = this.calculateCost({
// info,
// inputTokens,
// outputTokens,
// cacheWriteTokens,
// cacheReadTokens,
// })
const totalCost = isCacheUsed
? this.calculateCost({
info,
inputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
})
: undefined

yield {
type: "usage",
@@ -109,7 +151,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
cacheWriteTokens,
cacheReadTokens,
reasoningTokens,
// totalCost,
totalCost,
}
}
}
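For reference, a minimal sketch of the bookkeeping the handler above relies on; splitCachedPrefix is a hypothetical helper name, not part of this diff. The NodeCache instance maps a conversation-level cacheKey to the Gemini cache resource name plus the number of messages included when that cache was written, so a follow-up request can reuse the cached prefix and only resend the newer messages.

import NodeCache from "node-cache"
import type { Content } from "@google/genai"

type CacheEntry = { key: string; count: number }

// Same TTL settings as the constructor above: entries expire after five minutes.
const contentCaches = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

function splitCachedPrefix(cacheKey: string, contents: Content[]) {
	const entry = contentCaches.get<CacheEntry>(cacheKey)

	if (!entry) {
		// Nothing cached yet for this conversation: send everything uncached.
		return { cachedContent: undefined, uncachedContent: contents }
	}

	// Reuse the Gemini cache resource and send only the messages added after it was created.
	return { cachedContent: entry.key, uncachedContent: contents.slice(entry.count) }
}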
6 changes: 6 additions & 0 deletions src/api/transform/gemini-format.ts
@@ -76,3 +76,9 @@ export function convertAnthropicMessageToGemini(message: Anthropic.Messages.Mess
parts: convertAnthropicContentToGemini(message.content),
}
}

const getContentLength = ({ parts }: Content): number =>
parts?.reduce((length, { text }) => length + (text?.length ?? 0), 0) ?? 0

export const getMessagesLength = (contents: Content[]): number =>
contents.reduce((length, content) => length + getContentLength(content), 0)
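A quick usage sketch of the new helpers (the values and import path are illustrative): getMessagesLength sums the character length of every text part, and the provider adds the system-instruction length before comparing against the 16,384-character threshold that gates cache creation in gemini.ts.

import type { Content } from "@google/genai"
import { getMessagesLength } from "./gemini-format"

const systemInstruction = "You are a helpful assistant."

const contents: Content[] = [
	{ role: "user", parts: [{ text: "Summarize the design document." }] },
	{ role: "model", parts: [{ text: "Here is a summary of the key points..." }] },
]

// Character count (not tokens) across all text parts, plus the system prompt.
const contentsLength = systemInstruction.length + getMessagesLength(contents)

// Mirrors the gate added above: only conversations beyond ~16k characters get cached.
const shouldConsiderCaching = contentsLength > 16_384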
6 changes: 3 additions & 3 deletions src/shared/api.ts
@@ -679,7 +679,7 @@ export const geminiModels = {
maxTokens: 65_535,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 2.5, // This is the pricing for prompts above 200k tokens.
outputPrice: 15,
@@ -704,7 +704,7 @@ export const geminiModels = {
maxTokens: 8192,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 0.1,
outputPrice: 0.4,
@@ -755,7 +755,7 @@ export const geminiModels = {
maxTokens: 8192,
contextWindow: 1_048_576,
supportsImages: true,
supportsPromptCache: false,
supportsPromptCache: true,
isPromptCacheOptional: true,
inputPrice: 0.15, // This is the pricing for prompts above 128k tokens.
outputPrice: 0.6,