feat: add prompt caching support for Kimi K2 on Groq #7324
Changes from all commits
`@@ -1,9 +1,22 @@`

```ts
import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import type { ApiHandlerCreateMessageMetadata } from "../index"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
```
> **Contributor:** Is this import still needed? It appears to be unused, since the `createMessage` method is overridden and doesn't call `convertToOpenAiMessages`.
```ts
import { calculateApiCostOpenAI } from "../../shared/cost"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

// Enhanced usage interface to support Groq's cached token fields
interface GroqUsage extends OpenAI.CompletionUsage {
	prompt_tokens_details?: {
		cached_tokens?: number
	}
}
```
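For context, here is a hypothetical `usage` payload of the shape this interface models; the field values are invented for illustration and are not taken from the PR or from Groq's documentation:

```ts
// Hypothetical Groq usage payload (invented values, for illustration only).
const exampleUsage: GroqUsage = {
	prompt_tokens: 1200, // total prompt tokens, including cached ones
	completion_tokens: 350,
	total_tokens: 1550,
	prompt_tokens_details: {
		cached_tokens: 1024, // prompt tokens served from Groq's prefix cache
	},
}
```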
```ts
export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
	constructor(options: ApiHandlerOptions) {
		super({
```
@@ -16,4 +29,61 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> { | |
| defaultTemperature: 0.5, | ||
| }) | ||
| } | ||
|
|
||
| override async *createMessage( | ||
| systemPrompt: string, | ||
| messages: Anthropic.Messages.MessageParam[], | ||
| metadata?: ApiHandlerCreateMessageMetadata, | ||
| ): ApiStream { | ||
| const stream = await this.createStream(systemPrompt, messages, metadata) | ||
|
|
||
| for await (const chunk of stream) { | ||
| const delta = chunk.choices[0]?.delta | ||
|
|
||
| if (delta?.content) { | ||
| yield { | ||
| type: "text", | ||
| text: delta.content, | ||
| } | ||
| } | ||
|
|
||
| if (chunk.usage) { | ||
| yield* this.yieldUsage(chunk.usage as GroqUsage) | ||
|
> **Contributor:** Could we add type validation here to ensure `chunk.usage` conforms to the `GroqUsage` structure? The type assertion without validation could potentially cause runtime errors if the API response structure changes.
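A minimal sketch of such a guard, assuming the `GroqUsage` interface above; the `isGroqUsage` helper is hypothetical and not part of the PR:

```ts
// Hypothetical runtime guard for the GroqUsage shape; not part of the PR.
function isGroqUsage(value: unknown): value is GroqUsage {
	if (typeof value !== "object" || value === null) return false
	const usage = value as Record<string, unknown>
	const details = usage.prompt_tokens_details
	return (
		typeof usage.prompt_tokens === "number" &&
		typeof usage.completion_tokens === "number" &&
		(details === undefined || (typeof details === "object" && details !== null))
	)
}
```

The streaming loop could then call `yieldUsage` only when the guard passes, or fall back to zeroed token counts otherwise.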
```ts
			}
		}
	}

	private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
		const { info } = this.getModel()
		const inputTokens = usage?.prompt_tokens || 0
		const outputTokens = usage?.completion_tokens || 0

		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

		// Groq does not track cache writes
```
> **Contributor:** Could we expand this comment to provide more context? For example: "Groq does not track cache writes - only cache reads are reported in the API response. This is a limitation of the Groq API as of [date]."
```ts
		const cacheWriteTokens = 0

		// Calculate cost using OpenAI-compatible cost calculation
		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)

		// Calculate non-cached input tokens for proper reporting
		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)

		console.log("usage", {
```
> **Contributor:** Debug logging should be removed from production code. Could we remove this `console.log` statement?
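If the trace is still useful during development, one common alternative is to gate it behind an environment check rather than keep it unconditionally; a sketch, not part of the PR:

```ts
// Hypothetical debug gate; the reviewer's preference is to drop the log entirely.
if (process.env.NODE_ENV === "development") {
	console.debug("groq usage", { inputTokens: nonCachedInputTokens, outputTokens, cacheReadTokens })
}
```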
```ts
			inputTokens: nonCachedInputTokens,
			outputTokens,
			cacheWriteTokens,
			cacheReadTokens,
			totalCost,
		})

		yield {
			type: "usage",
			inputTokens: nonCachedInputTokens,
			outputTokens,
			cacheWriteTokens,
			cacheReadTokens,
			totalCost,
		}
	}
}
```
> **Contributor:** Consider adding edge case tests:
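A sketch of what such tests might look like, in a vitest style; the import paths, the `groqApiKey` option name, and reaching the private `yieldUsage` via an `any` cast are assumptions for illustration, not code from the PR:

```ts
import { describe, expect, it } from "vitest"

import { GroqHandler } from "../groq" // hypothetical path to the handler under test
import type { ApiHandlerOptions } from "../../../shared/api"

describe("GroqHandler.yieldUsage edge cases", () => {
	const handler = new GroqHandler({ groqApiKey: "test" } as ApiHandlerOptions)

	// Collects every chunk an ApiStream yields into an array.
	async function collect(stream: AsyncIterable<unknown>) {
		const chunks: any[] = []
		for await (const chunk of stream) chunks.push(chunk)
		return chunks
	}

	it("reports zeroed tokens when usage is undefined", async () => {
		// Cast to any to reach the private method; a test-only convenience.
		const [usage] = await collect((handler as any).yieldUsage(undefined))
		expect(usage.inputTokens).toBe(0)
		expect(usage.outputTokens).toBe(0)
		expect(usage.cacheReadTokens).toBe(0)
	})

	it("subtracts cached tokens from the reported input tokens", async () => {
		const [usage] = await collect(
			(handler as any).yieldUsage({
				prompt_tokens: 100,
				completion_tokens: 10,
				total_tokens: 110,
				prompt_tokens_details: { cached_tokens: 60 },
			}),
		)
		expect(usage.cacheReadTokens).toBe(60)
		expect(usage.inputTokens).toBe(40) // 100 total - 60 cached
	})
})
```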