3 changes: 2 additions & 1 deletion packages/types/src/providers/groq.ts
@@ -94,9 +94,10 @@ export const groqModels = {
maxTokens: 16384,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
supportsPromptCache: true,
inputPrice: 1.0,
outputPrice: 3.0,
cacheReadsPrice: 0.5, // 50% discount for cached input tokens
description: "Moonshot AI Kimi K2 Instruct 1T model, 128K context.",
},
"openai/gpt-oss-120b": {
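For context on the new cacheReadsPrice field: assuming the listed prices are USD per million tokens and cached prompt tokens are billed at cacheReadsPrice instead of inputPrice, a request with 100 prompt tokens (30 of them cached) and 50 completion tokens would cost roughly (70 × 1.0 + 30 × 0.5 + 50 × 3.0) / 1,000,000 ≈ $0.000235.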
48 changes: 47 additions & 1 deletion src/api/providers/__tests__/groq.spec.ts
@@ -108,7 +108,53 @@ describe("GroqHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
expect(firstChunk.value).toMatchObject({
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheWriteTokens: 0,
cacheReadTokens: 0,
})
// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
expect(typeof firstChunk.value.totalCost).toBe("number")
})

it("createMessage should handle cached tokens in usage data", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vitest
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [{ delta: {} }],
usage: {
prompt_tokens: 100,
completion_tokens: 50,
prompt_tokens_details: {
cached_tokens: 30,
},
},
},
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toMatchObject({
type: "usage",
inputTokens: 70, // 100 total - 30 cached
outputTokens: 50,
cacheWriteTokens: 0,
cacheReadTokens: 30,
})
expect(typeof firstChunk.value.totalCost).toBe("number")
})
Reviewer comment (Contributor):
Consider adding edge case tests:

  • When prompt_tokens_details is present but cached_tokens is undefined
  • When cached tokens exceed total prompt tokens (error case)
  • Verify actual cost calculation values instead of just checking the type

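One way the first suggested edge case might look, reusing this spec file's existing mockCreate and handler setup (a sketch, not part of this PR's diff):

it("createMessage should handle prompt_tokens_details without cached_tokens", async () => {
	mockCreate.mockImplementationOnce(() => {
		return {
			[Symbol.asyncIterator]: () => ({
				next: vitest
					.fn()
					.mockResolvedValueOnce({
						done: false,
						value: {
							choices: [{ delta: {} }],
							usage: {
								prompt_tokens: 100,
								completion_tokens: 50,
								prompt_tokens_details: {}, // cached_tokens is undefined
							},
						},
					})
					.mockResolvedValueOnce({ done: true }),
			}),
		}
	})

	const stream = handler.createMessage("system prompt", [])
	const firstChunk = await stream.next()

	// With no cached_tokens reported, every prompt token is treated as non-cached input.
	expect(firstChunk.done).toBe(false)
	expect(firstChunk.value).toMatchObject({
		type: "usage",
		inputTokens: 100,
		outputTokens: 50,
		cacheWriteTokens: 0,
		cacheReadTokens: 0,
	})
})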

it("createMessage should pass correct parameters to Groq client", async () => {
70 changes: 70 additions & 0 deletions src/api/providers/groq.ts
@@ -1,9 +1,22 @@
import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import type { ApiHandlerCreateMessageMetadata } from "../index"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
Reviewer comment (Contributor):
Is this import still needed? It appears to be unused since the createMessage method is overridden and doesn't call convertToOpenAiMessages.

import { calculateApiCostOpenAI } from "../../shared/cost"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

// Enhanced usage interface to support Groq's cached token fields
interface GroqUsage extends OpenAI.CompletionUsage {
prompt_tokens_details?: {
cached_tokens?: number
}
}

export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
constructor(options: ApiHandlerOptions) {
super({
@@ -16,4 +29,61 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
defaultTemperature: 0.5,
})
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const stream = await this.createStream(systemPrompt, messages, metadata)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield {
type: "text",
text: delta.content,
}
}

if (chunk.usage) {
yield* this.yieldUsage(chunk.usage as GroqUsage)
Reviewer comment (Contributor):
Could we add type validation here to ensure chunk.usage conforms to the GroqUsage structure? The type assertion without validation could potentially cause runtime errors if the API response structure changes. (A possible guard is sketched after this file's diff.)

}
}
}

private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
const { info } = this.getModel()
const inputTokens = usage?.prompt_tokens || 0
const outputTokens = usage?.completion_tokens || 0

const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

// Groq does not track cache writes
Reviewer comment (Contributor):
Could we expand this comment to provide more context? For example: 'Groq does not track cache writes - only cache reads are reported in the API response. This is a limitation of the Groq API as of [date].'

const cacheWriteTokens = 0

// Calculate cost using OpenAI-compatible cost calculation
const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)

// Calculate non-cached input tokens for proper reporting
const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)

console.log("usage", {
Reviewer comment (Contributor):
Debug logging should be removed from production code. Could we remove this console.log statement?

inputTokens: nonCachedInputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
totalCost,
})

yield {
type: "usage",
inputTokens: nonCachedInputTokens,
outputTokens,
cacheWriteTokens,
cacheReadTokens,
totalCost,
}
}
}
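
Following up on the type-validation comment above, a minimal runtime guard might look like the following (a sketch only, not part of the PR; it checks just the numeric fields yieldUsage actually reads):

// Sketch of the guard suggested in review.
function isGroqUsage(value: unknown): value is GroqUsage {
	if (typeof value !== "object" || value === null) {
		return false
	}
	const usage = value as Record<string, unknown>
	// prompt_tokens and completion_tokens are the required numeric fields;
	// prompt_tokens_details.cached_tokens stays optional and is already read defensively in yieldUsage.
	return typeof usage.prompt_tokens === "number" && typeof usage.completion_tokens === "number"
}

The call site in createMessage could then become `if (chunk.usage && isGroqUsage(chunk.usage)) { yield* this.yieldUsage(chunk.usage) }`, dropping the bare type assertion.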