
Commit faab314

Authored by daniel-lxs and Cline Contributors
feat: add prompt caching support for Kimi K2 on Groq (#7324)
Ported from the upstream Cline repository (original PR: cline/cline#5697).

- Added a GroqUsage interface to handle cached token fields
- Implemented proper cost calculation with cache read discounts
- Enabled prompt caching for the Kimi K2 model with a 50% discount on cached tokens
- Updated tests to verify caching functionality

Co-authored-by: Cline Contributors <[email protected]>
1 parent: f14e6ac

File tree

3 files changed: 119 additions, 2 deletions

packages/types/src/providers/groq.ts

Lines changed: 2 additions & 1 deletion
@@ -94,9 +94,10 @@ export const groqModels = {
 		maxTokens: 16384,
 		contextWindow: 131072,
 		supportsImages: false,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 1.0,
 		outputPrice: 3.0,
+		cacheReadsPrice: 0.5, // 50% discount for cached input tokens
 		description: "Moonshot AI Kimi K2 Instruct 1T model, 128K context.",
 	},
 	"openai/gpt-oss-120b": {

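To see what the new `cacheReadsPrice` means per request, here is a minimal sketch of the cost arithmetic. It assumes the shared cost helper bills cached prompt tokens at `cacheReadsPrice`, the remaining prompt tokens at `inputPrice`, and completion tokens at `outputPrice`, with all prices expressed per million tokens; `estimateCost` is a hypothetical standalone helper, and the authoritative logic is whatever `calculateApiCostOpenAI` in the shared cost module actually does (covered by cost.spec.ts).

```typescript
// Hypothetical sketch of the cost math this commit relies on.
// Prices are USD per million tokens, mirroring the Kimi K2 entry above.
const inputPrice = 1.0
const outputPrice = 3.0
const cacheReadsPrice = 0.5 // 50% discount for cached input tokens

function estimateCost(promptTokens: number, cachedTokens: number, completionTokens: number): number {
	const nonCachedTokens = Math.max(0, promptTokens - cachedTokens)
	return (nonCachedTokens * inputPrice + cachedTokens * cacheReadsPrice + completionTokens * outputPrice) / 1_000_000
}

// Example: 100 prompt tokens (30 served from cache) and 50 completion tokens:
// (70 * 1.0 + 30 * 0.5 + 50 * 3.0) / 1e6 = 0.000235 USD
console.log(estimateCost(100, 30, 50))
```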
src/api/providers/__tests__/groq.spec.ts

Lines changed: 47 additions & 1 deletion
@@ -108,7 +108,53 @@ describe("GroqHandler", () => {
 		const firstChunk = await stream.next()

 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 20,
+			cacheWriteTokens: 0,
+			cacheReadTokens: 0,
+		})
+		// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
+		expect(typeof firstChunk.value.totalCost).toBe("number")
+	})
+
+	it("createMessage should handle cached tokens in usage data", async () => {
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vitest
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [{ delta: {} }],
+								usage: {
+									prompt_tokens: 100,
+									completion_tokens: 50,
+									prompt_tokens_details: {
+										cached_tokens: 30,
+									},
+								},
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toMatchObject({
+			type: "usage",
+			inputTokens: 70, // 100 total - 30 cached
+			outputTokens: 50,
+			cacheWriteTokens: 0,
+			cacheReadTokens: 30,
+		})
+		expect(typeof firstChunk.value.totalCost).toBe("number")
 	})

 	it("createMessage should pass correct parameters to Groq client", async () => {
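For reference, the mocked stream mirrors the OpenAI-compatible usage payload returned by Groq, and the assertions check the usage chunk the handler maps it to. The sketch below restates that mapping; `MockedGroqUsage`, `ExpectedUsageChunk`, and `mapUsage` are hypothetical names introduced here for illustration, and the field names come from the tests and the handler in this commit rather than from official Groq documentation.

```typescript
// Provider-side usage payload shape the tests mock (OpenAI-compatible).
interface MockedGroqUsage {
	prompt_tokens: number // total prompt tokens, cached plus non-cached
	completion_tokens: number
	prompt_tokens_details?: { cached_tokens?: number }
}

// Usage chunk shape the handler is expected to yield.
interface ExpectedUsageChunk {
	type: "usage"
	inputTokens: number // non-cached prompt tokens
	outputTokens: number
	cacheWriteTokens: number // always 0: Groq does not report cache writes
	cacheReadTokens: number
	totalCost: number
}

// Sketch of the mapping the assertions encode (cost calculation omitted).
function mapUsage(usage: MockedGroqUsage): Omit<ExpectedUsageChunk, "totalCost"> {
	const cacheReadTokens = usage.prompt_tokens_details?.cached_tokens ?? 0
	return {
		type: "usage",
		inputTokens: Math.max(0, usage.prompt_tokens - cacheReadTokens),
		outputTokens: usage.completion_tokens,
		cacheWriteTokens: 0,
		cacheReadTokens,
	}
}

// mapUsage({ prompt_tokens: 100, completion_tokens: 50, prompt_tokens_details: { cached_tokens: 30 } })
// yields inputTokens 70, outputTokens 50, cacheWriteTokens 0, cacheReadTokens 30,
// matching the second test above.
```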

src/api/providers/groq.ts

Lines changed: 70 additions & 0 deletions
@@ -1,9 +1,22 @@
 import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"

 import type { ApiHandlerOptions } from "../../shared/api"
+import type { ApiHandlerCreateMessageMetadata } from "../index"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { calculateApiCostOpenAI } from "../../shared/cost"

 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

+// Enhanced usage interface to support Groq's cached token fields
+interface GroqUsage extends OpenAI.CompletionUsage {
+	prompt_tokens_details?: {
+		cached_tokens?: number
+	}
+}
+
 export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 	constructor(options: ApiHandlerOptions) {
 		super({
@@ -16,4 +29,61 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 			defaultTemperature: 0.5,
 		})
 	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const stream = await this.createStream(systemPrompt, messages, metadata)
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield* this.yieldUsage(chunk.usage as GroqUsage)
+			}
+		}
+	}
+
+	private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
+		const { info } = this.getModel()
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		// Groq does not track cache writes
+		const cacheWriteTokens = 0
+
+		// Calculate cost using OpenAI-compatible cost calculation
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+
+		// Calculate non-cached input tokens for proper reporting
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+
+		console.log("usage", {
+			inputTokens: nonCachedInputTokens,
+			outputTokens,
+			cacheWriteTokens,
+			cacheReadTokens,
+			totalCost,
+		})
+
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens,
+			cacheWriteTokens,
+			cacheReadTokens,
+			totalCost,
+		}
+	}
 }
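For orientation, here is a rough sketch of how a caller might consume the stream that `createMessage` yields and pick up the final usage chunk. The chunk fields follow the diff above; `runOnce` and the import path are hypothetical and would need adjusting to where the caller lives relative to src/api/providers/groq.ts.

```typescript
// Hypothetical consumer of GroqHandler.createMessage; adjust the import path
// to the caller's actual location in the repo.
import { GroqHandler } from "./src/api/providers/groq"

async function runOnce(handler: GroqHandler, systemPrompt: string) {
	let text = ""
	let totalCost = 0
	let cacheReadTokens = 0

	// createMessage yields "text" chunks as they stream and a "usage" chunk
	// whenever the provider reports usage (typically at the end of the stream).
	for await (const chunk of handler.createMessage(systemPrompt, [])) {
		if (chunk.type === "text") {
			text += chunk.text
		} else if (chunk.type === "usage") {
			totalCost = chunk.totalCost ?? 0
			cacheReadTokens = chunk.cacheReadTokens ?? 0
		}
	}

	return { text, totalCost, cacheReadTokens }
}
```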
