Skip to content
6 changes: 3 additions & 3 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import {
sambaNovaModels,
vertexModels,
vscodeLlmModels,
xaiModels,
internationalZAiModels,
minimaxModels,
} from "./providers/index.js"
Expand Down Expand Up @@ -50,6 +49,7 @@ export const dynamicProviders = [
"glama",
"roo",
"chutes",
"xai",
] as const

export type DynamicProvider = (typeof dynamicProviders)[number]
Expand Down Expand Up @@ -137,7 +137,6 @@ export const providerNames = [
"roo",
"sambanova",
"vertex",
"xai",
"zai",
] as const

Expand Down Expand Up @@ -354,6 +353,7 @@ const fakeAiSchema = baseProviderSettingsSchema.extend({

// Provider-settings schema for xAI. Extends the shared model-id schema with
// xAI-specific fields; both are optional so partially-configured profiles
// still validate.
const xaiSchema = apiModelIdProviderModelSchema.extend({
	xaiApiKey: z.string().optional(),
	// User override for the model's context window (in tokens). Must be a
	// positive integer when present — the xAI API does not report context
	// windows, so this lets users correct/set it manually.
	xaiModelContextWindow: z.number().int().min(1).optional(),
})

const groqSchema = apiModelIdProviderModelSchema.extend({
Expand Down Expand Up @@ -709,7 +709,7 @@ export const MODELS_BY_PROVIDER: Record<
label: "VS Code LM API",
models: Object.keys(vscodeLlmModels),
},
xai: { id: "xai", label: "xAI (Grok)", models: Object.keys(xaiModels) },
xai: { id: "xai", label: "xAI (Grok)", models: [] },
zai: { id: "zai", label: "Zai", models: Object.keys(internationalZAiModels) },

// Dynamic providers; models pulled from remote APIs.
Expand Down
102 changes: 36 additions & 66 deletions packages/types/src/providers/xai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,63 @@ import type { ModelInfo } from "../model.js"
// https://docs.x.ai/docs/api-reference
export type XAIModelId = keyof typeof xaiModels

export const xaiDefaultModelId: XAIModelId = "grok-code-fast-1"
export const xaiDefaultModelId: XAIModelId = "grok-4-fast-reasoning"

/**
* Partial ModelInfo for xAI static registry.
* Contains only fields not available from the xAI API:
* - contextWindow: Not provided by API
* - maxTokens: Not provided by API
* - description: User-friendly descriptions
* - supportsReasoningEffort: Special capability flag
*
* All other fields (pricing, supportsPromptCache, supportsImages) are fetched dynamically.
*/
type XAIStaticModelInfo = Pick<ModelInfo, "contextWindow" | "description"> & {
maxTokens?: number | null
supportsReasoningEffort?: boolean
}

export const xaiModels = {
"grok-code-fast-1": {
maxTokens: 16_384,
contextWindow: 262_144,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.2,
outputPrice: 1.5,
cacheWritesPrice: 0.02,
cacheReadsPrice: 0.02,
contextWindow: 256_000,
description: "xAI's Grok Code Fast model with 256K context window",
},
"grok-4": {
maxTokens: 8192,
contextWindow: 256000,
supportsImages: true,
supportsPromptCache: true,
inputPrice: 3.0,
outputPrice: 15.0,
cacheWritesPrice: 0.75,
cacheReadsPrice: 0.75,
"grok-4-0709": {
maxTokens: 16_384,
contextWindow: 256_000,
description: "xAI's Grok-4 model with 256K context window",
},
"grok-4-fast-non-reasoning": {
maxTokens: 32_768,
contextWindow: 2_000_000,
description: "xAI's Grok-4 Fast Non-Reasoning model with 2M context window",
},
"grok-4-fast-reasoning": {
maxTokens: 32_768,
contextWindow: 2_000_000,
description: "xAI's Grok-4 Fast Reasoning model with 2M context window",
},
"grok-3": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 3.0,
outputPrice: 15.0,
cacheWritesPrice: 0.75,
cacheReadsPrice: 0.75,
contextWindow: 131_072,
description: "xAI's Grok-3 model with 128K context window",
},
"grok-3-fast": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 5.0,
outputPrice: 25.0,
cacheWritesPrice: 1.25,
cacheReadsPrice: 1.25,
description: "xAI's Grok-3 fast model with 128K context window",
},
"grok-3-mini": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.3,
outputPrice: 0.5,
cacheWritesPrice: 0.07,
cacheReadsPrice: 0.07,
contextWindow: 131_072,
description: "xAI's Grok-3 mini model with 128K context window",
supportsReasoningEffort: true,
},
"grok-3-mini-fast": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.6,
outputPrice: 4.0,
cacheWritesPrice: 0.15,
cacheReadsPrice: 0.15,
description: "xAI's Grok-3 mini fast model with 128K context window",
supportsReasoningEffort: true,
},
"grok-2-1212": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 2.0,
outputPrice: 10.0,
description: "xAI's Grok-2 model (version 1212) with 128K context window",
contextWindow: 32_768,
description: "xAI's Grok-2 model (version 1212) with 32K context window",
},
"grok-2-vision-1212": {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 2.0,
outputPrice: 10.0,
contextWindow: 32_768,
description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window",
},
} as const satisfies Record<string, ModelInfo>
} as const satisfies Record<string, XAIStaticModelInfo>
17 changes: 14 additions & 3 deletions src/api/providers/__tests__/xai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ describe("XAIHandler", () => {
it("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(xaiDefaultModelId)
expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
expect(model.info).toMatchObject({
contextWindow: xaiModels[xaiDefaultModelId].contextWindow,
maxTokens: xaiModels[xaiDefaultModelId].maxTokens,
description: xaiModels[xaiDefaultModelId].description,
})
expect(model.info.supportsPromptCache).toBe(false) // Placeholder until dynamic data loads
})

test("should return specified model when valid model is provided", () => {
Expand All @@ -66,7 +71,12 @@ describe("XAIHandler", () => {
const model = handlerWithModel.getModel()

expect(model.id).toBe(testModelId)
expect(model.info).toEqual(xaiModels[testModelId])
expect(model.info).toMatchObject({
contextWindow: xaiModels[testModelId].contextWindow,
maxTokens: xaiModels[testModelId].maxTokens,
description: xaiModels[testModelId].description,
})
expect(model.info.supportsPromptCache).toBe(false) // Placeholder until dynamic data loads
})

it("should include reasoning_effort parameter for mini models", async () => {
Expand Down Expand Up @@ -234,12 +244,13 @@ describe("XAIHandler", () => {

// Verify the usage data
expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({
expect(firstChunk.value).toMatchObject({
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheReadTokens: 5,
cacheWriteTokens: 15,
totalCost: expect.any(Number),
})
})

Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
}

if (inputTokens > 0 || outputTokens > 0 || cacheWriteTokens > 0 || cacheReadTokens > 0) {
const { totalCost } = calculateApiCostAnthropic(
const totalCost = calculateApiCostAnthropic(
this.getModel().info,
inputTokens,
outputTokens,
Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/cerebras.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
const { info } = this.getModel()
// Use actual token usage from the last request
const { inputTokens, outputTokens } = this.lastUsage
const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens)
const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens)
return totalCost
}
}
4 changes: 2 additions & 2 deletions src/api/providers/deepinfra.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

const { totalCost } = modelInfo
const totalCost = modelInfo
? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
: { totalCost: 0 }
: 0

return {
type: "usage",
Expand Down
56 changes: 56 additions & 0 deletions src/api/providers/fetchers/__tests__/xai.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import axios from "axios"

vi.mock("axios")

import { getXaiModels } from "../xai"
import { xaiModels } from "@roo-code/types"

describe("getXaiModels", () => {
	// axios is vi.mock'ed above; cast once so each test can install its own
	// `get` implementation without repeated casts.
	const mockedAxios = axios as unknown as { get: ReturnType<typeof vi.fn> }

	beforeEach(() => {
		// Reset call history/implementations so tests don't leak mocks into
		// each other.
		vi.clearAllMocks()
	})

	it("returns mapped models with pricing and modalities (augmenting static info when available)", async () => {
		// Simulate the xAI /models response shape: a `models` array with
		// modality lists and token prices.
		mockedAxios.get = vi.fn().mockResolvedValue({
			data: {
				models: [
					{
						id: "grok-3",
						input_modalities: ["text"],
						output_modalities: ["text"],
						// NOTE(review): comments below assume prices are in
						// fractional cents and get converted to $ per 1M
						// tokens — confirm against the fetcher's conversion.
						prompt_text_token_price: 2000, // 2000 fractional cents = $0.20 per 1M tokens
						cached_prompt_text_token_price: 500, // 500 fractional cents = $0.05 per 1M tokens
						completion_text_token_price: 10000, // 10000 fractional cents = $1.00 per 1M tokens
						aliases: ["grok-3-latest"],
					},
				],
			},
		})

		const result = await getXaiModels("key", "https://api.x.ai/v1")
		expect(result["grok-3"]).toBeDefined()
		// "text"-only input modalities => no image support.
		expect(result["grok-3"]?.supportsImages).toBe(false)
		expect(result["grok-3"]?.inputPrice).toBeCloseTo(0.2) // $0.20 per 1M tokens
		expect(result["grok-3"]?.outputPrice).toBeCloseTo(1.0) // $1.00 per 1M tokens
		expect(result["grok-3"]?.cacheReadsPrice).toBeCloseTo(0.05) // $0.05 per 1M tokens
		// aliases are not added to avoid UI duplication
		expect(result["grok-3-latest"]).toBeUndefined()
	})

	it("returns empty object on schema mismatches (graceful degradation)", async () => {
		// Deliberately wrong response shape (`data` array of bogus entries):
		// the fetcher must swallow the mismatch and return {} rather than throw.
		mockedAxios.get = vi.fn().mockResolvedValue({
			data: { data: [{ bogus: true }] },
		})
		const result = await getXaiModels("key")
		expect(result).toEqual({})
	})

	it("includes Authorization header when apiKey provided", async () => {
		mockedAxios.get = vi.fn().mockResolvedValue({ data: { data: [] } })
		await getXaiModels("secret")
		// Inspect the second argument (axios config) of the first call for the
		// bearer token built from the provided key.
		expect((axios.get as any).mock.calls[0][1].headers.Authorization).toBe("Bearer secret")
	})
})
6 changes: 5 additions & 1 deletion src/api/providers/fetchers/modelCache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { getDeepInfraModels } from "./deepinfra"
import { getHuggingFaceModels } from "./huggingface"
import { getRooModels } from "./roo"
import { getChutesModels } from "./chutes"
import { getXaiModels } from "./xai"

const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

Expand Down Expand Up @@ -101,6 +102,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
case "huggingface":
models = await getHuggingFaceModels()
break
case "xai":
models = await getXaiModels(options.apiKey, options.baseUrl)
break
case "roo": {
// Roo Code Cloud provider requires baseUrl and optional apiKey
const rooBaseUrl =
Expand All @@ -121,7 +125,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
// Cache the fetched models (even if empty, to signify a successful fetch with no models).
memoryCache.set(provider, models)

await writeModels(provider, models).catch((err) =>
await writeModels(provider, models || {}).catch((err) =>
console.error(`[getModels] Error writing ${provider} models to file cache:`, err),
)

Expand Down
Loading