@@ -14,6 +14,8 @@ import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from
 import type { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 
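+// Cache lifetime in minutes. The Gemini caches API expects a TTL in seconds,
+// so the value is converted where the cache entry is created.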
+const CACHE_TTL = 5
+
 export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: GoogleGenAI
@@ -31,15 +33,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
 		messages: Anthropic.Messages.MessageParam[],
 		taskId?: string,
 	): ApiStream {
-		const { id: model, thinkingConfig, maxOutputTokens, supportsPromptCache } = this.getModel()
+		const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()
 
 		const contents = messages.map(convertAnthropicMessageToGemini)
 		let uncachedContent: Content[] | undefined = undefined
 		let cachedContent: string | undefined = undefined
-		let cacheWriteTokens: number = 0
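+		// undefined rather than 0 so usage reporting only includes a cache write
+		// when one actually occurred.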
+		let cacheWriteTokens: number | undefined = undefined
 
 		// https://ai.google.dev/gemini-api/docs/caching?lang=node
-		if (supportsPromptCache && taskId) {
+		if (info.supportsPromptCache && taskId) {
 			const cacheEntry = this.contentCaches.get(taskId)
 
 			if (cacheEntry) {
@@ -49,7 +51,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
 
 			const newCacheEntry = await this.client.caches.create({
 				model,
-				config: { contents, systemInstruction, ttl: "300s" },
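+				// Derive the TTL string (minutes → seconds, i.e. "300s") from CACHE_TTL
+				// so it stays in sync with the storage-cost math in calculateCost.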
+				config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
 			})
 
 			if (newCacheEntry.name) {
@@ -89,26 +91,31 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
 
 		if (lastUsageMetadata) {
 			const inputTokens = lastUsageMetadata.promptTokenCount ?? 0
-			const cachedInputTokens = lastUsageMetadata.cachedContentTokenCount ?? 0
 			const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0
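+			// May be undefined when the response used no cached content or thinking tokens.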
+			const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
+			const thinkingTokens = lastUsageMetadata.thoughtsTokenCount
+
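+			// Price the request up front so the usage chunk can carry a totalCost.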
+			const totalCost = this.calculateCost({
+				info,
+				inputTokens,
+				outputTokens,
+				cacheWriteTokens,
+				cacheReadTokens,
+			})
 
 			yield {
 				type: "usage",
-				inputTokens: inputTokens - cachedInputTokens,
+				inputTokens,
 				outputTokens,
 				cacheWriteTokens,
-				cacheReadTokens: cachedInputTokens,
+				cacheReadTokens,
+				thinkingTokens,
+				totalCost,
 			}
 		}
 	}
 
-	override getModel(): {
-		id: GeminiModelId
-		info: ModelInfo
-		thinkingConfig?: ThinkingConfig
-		maxOutputTokens?: number
-		supportsPromptCache?: boolean
-	} {
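+	// The explicit return-type annotation is dropped so TypeScript can infer it
+	// now that the two return sites have different shapes.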
+	override getModel() {
 		let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId
 		let info: ModelInfo = geminiModels[id]
 
@@ -125,7 +132,6 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
 						? { thinkingBudget: this.options.modelMaxThinkingTokens }
 						: undefined,
 					maxOutputTokens: this.options.modelMaxTokens ?? info.maxTokens ?? undefined,
-					supportsPromptCache: info.supportsPromptCache,
 				}
 			}
 		}
@@ -135,7 +141,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
 			info = geminiModels[geminiDefaultModelId]
 		}
 
-		return { id, info, supportsPromptCache: info.supportsPromptCache }
+		return { id, info }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
@@ -183,4 +189,60 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
 			return super.countTokens(content)
 		}
 	}
+
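+	// Estimates the dollar cost of a request from token counts and the model's
+	// per-million-token prices, including prompt cache reads and writes.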
+	private calculateCost({
+		info,
+		inputTokens,
+		outputTokens,
+		cacheWriteTokens,
+		cacheReadTokens,
+	}: {
+		info: ModelInfo
+		inputTokens: number
+		outputTokens: number
+		cacheWriteTokens?: number
+		cacheReadTokens?: number
+	}) {
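+		// Cost can't be computed without the model's full pricing info.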
+		if (!info.inputPrice || !info.outputPrice || !info.cacheWritesPrice || !info.cacheReadsPrice) {
+			return undefined
+		}
+
+		let inputPrice = info.inputPrice
+		let outputPrice = info.outputPrice
+		let cacheWritesPrice = info.cacheWritesPrice
+		let cacheReadsPrice = info.cacheReadsPrice
+
+		// If there's tiered pricing, adjust the token and cache prices based on
+		// the number of input tokens used.
+		if (info.tiers) {
+			const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow)
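+			// Assumes tiers are ordered from smallest to largest context window, so
+			// this picks the first tier the prompt fits into.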
+
+			if (tier) {
+				inputPrice = tier.inputPrice ?? inputPrice
+				outputPrice = tier.outputPrice ?? outputPrice
+				cacheWritesPrice = tier.cacheWritesPrice ?? cacheWritesPrice
+				cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice
+			}
+		}
+
+		let inputTokensCost = inputPrice * (inputTokens / 1_000_000)
+		let outputTokensCost = outputPrice * (outputTokens / 1_000_000)
+		let cacheWriteCost = 0
+		let cacheReadCost = 0
+
+		// Cache writes: billed at the input token price plus CACHE_TTL minutes of
+		// cache storage, i.e. input token cost + (cache storage price × (5 min / 60 min)).
+		// The input-token component is already part of inputTokensCost above, so
+		// only the storage component is added here.
+		if (cacheWriteTokens) {
+			cacheWriteCost = cacheWritesPrice * (cacheWriteTokens / 1_000_000) * (CACHE_TTL / 60)
+		}
+
+		// Cache reads: billed at 0.25 × the original input token price; only the
+		// uncached portion of the prompt is billed at the full input price.
+		if (cacheReadTokens) {
+			const uncachedReadTokens = inputTokens - cacheReadTokens
+			cacheReadCost = cacheReadsPrice * (cacheReadTokens / 1_000_000)
+			inputTokensCost = inputPrice * (uncachedReadTokens / 1_000_000)
+		}
+
+		return inputTokensCost + outputTokensCost + cacheWriteCost + cacheReadCost
+	}
 }