Skip to content

Commit 417432f

Browse files
committed
Support tiered pricing
1 parent f08d8a2 commit 417432f

File tree

10 files changed

+238
-47
lines changed

10 files changed

+238
-47
lines changed

src/api/providers/__tests__/gemini.test.ts

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,15 @@ describe("GeminiHandler", () => {
7272

7373
// Should have 3 chunks: 'Hello', ' world!', and usage info
7474
expect(chunks.length).toBe(3)
75-
expect(chunks[0]).toEqual({
76-
type: "text",
77-
text: "Hello",
78-
})
79-
expect(chunks[1]).toEqual({
80-
type: "text",
81-
text: " world!",
82-
})
75+
expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
76+
expect(chunks[1]).toEqual({ type: "text", text: " world!" })
8377
expect(chunks[2]).toEqual({
8478
type: "usage",
8579
inputTokens: 10,
8680
outputTokens: 5,
81+
cacheReadTokens: undefined,
82+
cacheWriteTokens: undefined,
83+
thinkingTokens: undefined,
8784
})
8885

8986
// Verify the call to generateContentStream

src/api/providers/gemini.ts

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from
1414
import type { ApiStream } from "../transform/stream"
1515
import { BaseProvider } from "./base-provider"
1616

17+
const CACHE_TTL = 5
18+
1719
export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
1820
protected options: ApiHandlerOptions
1921
private client: GoogleGenAI
@@ -31,15 +33,15 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
3133
messages: Anthropic.Messages.MessageParam[],
3234
taskId?: string,
3335
): ApiStream {
34-
const { id: model, thinkingConfig, maxOutputTokens, supportsPromptCache } = this.getModel()
36+
const { id: model, thinkingConfig, maxOutputTokens, info } = this.getModel()
3537

3638
const contents = messages.map(convertAnthropicMessageToGemini)
3739
let uncachedContent: Content[] | undefined = undefined
3840
let cachedContent: string | undefined = undefined
39-
let cacheWriteTokens: number = 0
41+
let cacheWriteTokens: number | undefined = undefined
4042

4143
// https://ai.google.dev/gemini-api/docs/caching?lang=node
42-
if (supportsPromptCache && taskId) {
44+
if (info.supportsPromptCache && taskId) {
4345
const cacheEntry = this.contentCaches.get(taskId)
4446

4547
if (cacheEntry) {
@@ -49,7 +51,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
4951

5052
const newCacheEntry = await this.client.caches.create({
5153
model,
52-
config: { contents, systemInstruction, ttl: "300s" },
54+
config: { contents, systemInstruction, ttl: `${CACHE_TTL * 60}s` },
5355
})
5456

5557
if (newCacheEntry.name) {
@@ -89,26 +91,31 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
8991

9092
if (lastUsageMetadata) {
9193
const inputTokens = lastUsageMetadata.promptTokenCount ?? 0
92-
const cachedInputTokens = lastUsageMetadata.cachedContentTokenCount ?? 0
9394
const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0
95+
const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
96+
const thinkingTokens = lastUsageMetadata.thoughtsTokenCount
97+
98+
const totalCost = this.calculateCost({
99+
info,
100+
inputTokens,
101+
outputTokens,
102+
cacheWriteTokens,
103+
cacheReadTokens,
104+
})
94105

95106
yield {
96107
type: "usage",
97-
inputTokens: inputTokens - cachedInputTokens,
108+
inputTokens,
98109
outputTokens,
99110
cacheWriteTokens,
100-
cacheReadTokens: cachedInputTokens,
111+
cacheReadTokens,
112+
thinkingTokens,
113+
totalCost,
101114
}
102115
}
103116
}
104117

105-
override getModel(): {
106-
id: GeminiModelId
107-
info: ModelInfo
108-
thinkingConfig?: ThinkingConfig
109-
maxOutputTokens?: number
110-
supportsPromptCache?: boolean
111-
} {
118+
override getModel() {
112119
let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId
113120
let info: ModelInfo = geminiModels[id]
114121

@@ -125,7 +132,6 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
125132
? { thinkingBudget: this.options.modelMaxThinkingTokens }
126133
: undefined,
127134
maxOutputTokens: this.options.modelMaxTokens ?? info.maxTokens ?? undefined,
128-
supportsPromptCache: info.supportsPromptCache,
129135
}
130136
}
131137
}
@@ -135,7 +141,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
135141
info = geminiModels[geminiDefaultModelId]
136142
}
137143

138-
return { id, info, supportsPromptCache: info.supportsPromptCache }
144+
return { id, info }
139145
}
140146

141147
async completePrompt(prompt: string): Promise<string> {
@@ -183,4 +189,60 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
183189
return super.countTokens(content)
184190
}
185191
}
192+
193+
private calculateCost({
194+
info,
195+
inputTokens,
196+
outputTokens,
197+
cacheWriteTokens,
198+
cacheReadTokens,
199+
}: {
200+
info: ModelInfo
201+
inputTokens: number
202+
outputTokens: number
203+
cacheWriteTokens?: number
204+
cacheReadTokens?: number
205+
}) {
206+
if (!info.inputPrice || !info.outputPrice || !info.cacheWritesPrice || !info.cacheReadsPrice) {
207+
return undefined
208+
}
209+
210+
let inputPrice = info.inputPrice
211+
let outputPrice = info.outputPrice
212+
let cacheWritesPrice = info.cacheWritesPrice
213+
let cacheReadsPrice = info.cacheReadsPrice
214+
215+
// If there's tiered pricing then adjust the input and output token prices
216+
// based on the input tokens used.
217+
if (info.tiers) {
218+
const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow)
219+
220+
if (tier) {
221+
inputPrice = tier.inputPrice ?? inputPrice
222+
outputPrice = tier.outputPrice ?? outputPrice
223+
cacheWritesPrice = tier.cacheWritesPrice ?? cacheWritesPrice
224+
cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice
225+
}
226+
}
227+
228+
let inputTokensCost = inputPrice * (inputTokens / 1_000_000)
229+
let outputTokensCost = outputPrice * (outputTokens / 1_000_000)
230+
let cacheWriteCost = 0
231+
let cacheReadCost = 0
232+
233+
// Cache Writes: Charged at the input token cost plus 5 minutes of cache storage.
234+
// Example: Cache write cost = Input token price + (Cache storage price × (5 minutes / 60 minutes))
235+
if (cacheWriteTokens) {
236+
cacheWriteCost = cacheWritesPrice * (cacheWriteTokens / 1_000_000) * (CACHE_TTL / 60)
237+
}
238+
239+
// Cache Reads: Charged at 0.25 × the original input token cost.
240+
if (cacheReadTokens) {
241+
const uncachedReadTokens = inputTokens - cacheReadTokens
242+
cacheReadCost = cacheReadsPrice * (cacheReadTokens / 1_000_000)
243+
inputTokensCost = inputPrice * (uncachedReadTokens / 1_000_000)
244+
}
245+
246+
return inputTokensCost + outputTokensCost + cacheWriteCost + cacheReadCost
247+
}
186248
}

src/api/transform/stream.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ export interface ApiStreamUsageChunk {
1717
outputTokens: number
1818
cacheWriteTokens?: number
1919
cacheReadTokens?: number
20-
totalCost?: number // openrouter
20+
thinkingTokens?: number
21+
totalCost?: number
2122
}

src/exports/roo-code.d.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ type ProviderSettings = {
4646
minTokensPerCachePoint?: number | undefined
4747
maxCachePoints?: number | undefined
4848
cachableFields?: string[] | undefined
49+
tiers?:
50+
| {
51+
contextWindow: number
52+
inputPrice?: number | undefined
53+
outputPrice?: number | undefined
54+
cacheWritesPrice?: number | undefined
55+
cacheReadsPrice?: number | undefined
56+
}[]
57+
| undefined
4958
} | null)
5059
| undefined
5160
glamaApiKey?: string | undefined
@@ -69,6 +78,15 @@ type ProviderSettings = {
6978
minTokensPerCachePoint?: number | undefined
7079
maxCachePoints?: number | undefined
7180
cachableFields?: string[] | undefined
81+
tiers?:
82+
| {
83+
contextWindow: number
84+
inputPrice?: number | undefined
85+
outputPrice?: number | undefined
86+
cacheWritesPrice?: number | undefined
87+
cacheReadsPrice?: number | undefined
88+
}[]
89+
| undefined
7290
} | null)
7391
| undefined
7492
openRouterBaseUrl?: string | undefined
@@ -112,6 +130,15 @@ type ProviderSettings = {
112130
minTokensPerCachePoint?: number | undefined
113131
maxCachePoints?: number | undefined
114132
cachableFields?: string[] | undefined
133+
tiers?:
134+
| {
135+
contextWindow: number
136+
inputPrice?: number | undefined
137+
outputPrice?: number | undefined
138+
cacheWritesPrice?: number | undefined
139+
cacheReadsPrice?: number | undefined
140+
}[]
141+
| undefined
115142
} | null)
116143
| undefined
117144
openAiUseAzure?: boolean | undefined
@@ -158,6 +185,15 @@ type ProviderSettings = {
158185
minTokensPerCachePoint?: number | undefined
159186
maxCachePoints?: number | undefined
160187
cachableFields?: string[] | undefined
188+
tiers?:
189+
| {
190+
contextWindow: number
191+
inputPrice?: number | undefined
192+
outputPrice?: number | undefined
193+
cacheWritesPrice?: number | undefined
194+
cacheReadsPrice?: number | undefined
195+
}[]
196+
| undefined
161197
} | null)
162198
| undefined
163199
requestyApiKey?: string | undefined
@@ -180,6 +216,15 @@ type ProviderSettings = {
180216
minTokensPerCachePoint?: number | undefined
181217
maxCachePoints?: number | undefined
182218
cachableFields?: string[] | undefined
219+
tiers?:
220+
| {
221+
contextWindow: number
222+
inputPrice?: number | undefined
223+
outputPrice?: number | undefined
224+
cacheWritesPrice?: number | undefined
225+
cacheReadsPrice?: number | undefined
226+
}[]
227+
| undefined
183228
} | null)
184229
| undefined
185230
xaiApiKey?: string | undefined

src/exports/types.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ type ProviderSettings = {
4747
minTokensPerCachePoint?: number | undefined
4848
maxCachePoints?: number | undefined
4949
cachableFields?: string[] | undefined
50+
tiers?:
51+
| {
52+
contextWindow: number
53+
inputPrice?: number | undefined
54+
outputPrice?: number | undefined
55+
cacheWritesPrice?: number | undefined
56+
cacheReadsPrice?: number | undefined
57+
}[]
58+
| undefined
5059
} | null)
5160
| undefined
5261
glamaApiKey?: string | undefined
@@ -70,6 +79,15 @@ type ProviderSettings = {
7079
minTokensPerCachePoint?: number | undefined
7180
maxCachePoints?: number | undefined
7281
cachableFields?: string[] | undefined
82+
tiers?:
83+
| {
84+
contextWindow: number
85+
inputPrice?: number | undefined
86+
outputPrice?: number | undefined
87+
cacheWritesPrice?: number | undefined
88+
cacheReadsPrice?: number | undefined
89+
}[]
90+
| undefined
7391
} | null)
7492
| undefined
7593
openRouterBaseUrl?: string | undefined
@@ -113,6 +131,15 @@ type ProviderSettings = {
113131
minTokensPerCachePoint?: number | undefined
114132
maxCachePoints?: number | undefined
115133
cachableFields?: string[] | undefined
134+
tiers?:
135+
| {
136+
contextWindow: number
137+
inputPrice?: number | undefined
138+
outputPrice?: number | undefined
139+
cacheWritesPrice?: number | undefined
140+
cacheReadsPrice?: number | undefined
141+
}[]
142+
| undefined
116143
} | null)
117144
| undefined
118145
openAiUseAzure?: boolean | undefined
@@ -159,6 +186,15 @@ type ProviderSettings = {
159186
minTokensPerCachePoint?: number | undefined
160187
maxCachePoints?: number | undefined
161188
cachableFields?: string[] | undefined
189+
tiers?:
190+
| {
191+
contextWindow: number
192+
inputPrice?: number | undefined
193+
outputPrice?: number | undefined
194+
cacheWritesPrice?: number | undefined
195+
cacheReadsPrice?: number | undefined
196+
}[]
197+
| undefined
162198
} | null)
163199
| undefined
164200
requestyApiKey?: string | undefined
@@ -181,6 +217,15 @@ type ProviderSettings = {
181217
minTokensPerCachePoint?: number | undefined
182218
maxCachePoints?: number | undefined
183219
cachableFields?: string[] | undefined
220+
tiers?:
221+
| {
222+
contextWindow: number
223+
inputPrice?: number | undefined
224+
outputPrice?: number | undefined
225+
cacheWritesPrice?: number | undefined
226+
cacheReadsPrice?: number | undefined
227+
}[]
228+
| undefined
184229
} | null)
185230
| undefined
186231
xaiApiKey?: string | undefined

src/schemas/index.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ export const modelInfoSchema = z.object({
114114
minTokensPerCachePoint: z.number().optional(),
115115
maxCachePoints: z.number().optional(),
116116
cachableFields: z.array(z.string()).optional(),
117+
tiers: z
118+
.array(
119+
z.object({
120+
contextWindow: z.number(),
121+
inputPrice: z.number().optional(),
122+
outputPrice: z.number().optional(),
123+
cacheWritesPrice: z.number().optional(),
124+
cacheReadsPrice: z.number().optional(),
125+
}),
126+
)
127+
.optional(),
117128
})
118129

119130
export type ModelInfo = z.infer<typeof modelInfoSchema>

0 commit comments

Comments (0)