diff --git a/packages/types/src/providers/groq.ts b/packages/types/src/providers/groq.ts
index cab0c69900..feb1777ce3 100644
--- a/packages/types/src/providers/groq.ts
+++ b/packages/types/src/providers/groq.ts
@@ -94,9 +94,10 @@ export const groqModels = {
 		maxTokens: 16384,
 		contextWindow: 131072,
 		supportsImages: false,
-		supportsPromptCache: false,
+		supportsPromptCache: true,
 		inputPrice: 1.0,
 		outputPrice: 3.0,
+		cacheReadsPrice: 0.5, // 50% discount for cached input tokens
 		description: "Moonshot AI Kimi K2 Instruct 1T model, 128K context.",
 	},
 	"openai/gpt-oss-120b": {
diff --git a/src/api/providers/__tests__/groq.spec.ts b/src/api/providers/__tests__/groq.spec.ts
index a943e84daa..52846617f4 100644
--- a/src/api/providers/__tests__/groq.spec.ts
+++ b/src/api/providers/__tests__/groq.spec.ts
@@ -108,7 +108,53 @@ describe("GroqHandler", () => {
 		const firstChunk = await stream.next()
 
 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 20,
+			cacheWriteTokens: 0,
+			cacheReadTokens: 0,
+		})
+		// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
+		expect(typeof firstChunk.value.totalCost).toBe("number")
+	})
+
+	it("createMessage should handle cached tokens in usage data", async () => {
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					next: vitest
+						.fn()
+						.mockResolvedValueOnce({
+							done: false,
+							value: {
+								choices: [{ delta: {} }],
+								usage: {
+									prompt_tokens: 100,
+									completion_tokens: 50,
+									prompt_tokens_details: {
+										cached_tokens: 30,
+									},
+								},
+							},
+						})
+						.mockResolvedValueOnce({ done: true }),
+				}),
+			}
+		})
+
+		const stream = handler.createMessage("system prompt", [])
+		const firstChunk = await stream.next()
+
+		expect(firstChunk.done).toBe(false)
+		expect(firstChunk.value).toMatchObject({
+			type: "usage",
+			inputTokens: 70, // 100 total - 30 cached
+			outputTokens: 50,
+			cacheWriteTokens: 0,
+			cacheReadTokens: 30,
+		})
+		expect(typeof firstChunk.value.totalCost).toBe("number")
 	})
 
 	it("createMessage should pass correct parameters to Groq client", async () => {
diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts
index 7583edc51c..de07f7c46f 100644
--- a/src/api/providers/groq.ts
+++ b/src/api/providers/groq.ts
@@ -1,9 +1,22 @@
 import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
+import type { ApiHandlerCreateMessageMetadata } from "../index"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { calculateApiCostOpenAI } from "../../shared/cost"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
+// Enhanced usage interface to support Groq's cached token fields
+interface GroqUsage extends OpenAI.CompletionUsage {
+	prompt_tokens_details?: {
+		cached_tokens?: number
+	}
+}
+
 export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 	constructor(options: ApiHandlerOptions) {
 		super({
@@ -16,4 +29,61 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 			defaultTemperature: 0.5,
 		})
 	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const stream = await this.createStream(systemPrompt, messages, metadata)
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield* this.yieldUsage(chunk.usage as GroqUsage)
+			}
+		}
+	}
+
+	private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
+		const { info } = this.getModel()
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		// Groq does not track cache writes
+		const cacheWriteTokens = 0
+
+		// Calculate cost using OpenAI-compatible cost calculation
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+
+		// Calculate non-cached input tokens for proper reporting
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+
+		console.log("usage", {
+			inputTokens: nonCachedInputTokens,
+			outputTokens,
+			cacheWriteTokens,
+			cacheReadTokens,
+			totalCost,
+		})
+
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens,
+			cacheWriteTokens,
+			cacheReadTokens,
+			totalCost,
+		}
+	}
 }