diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index fc69948d4767..636214e9a179 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -21,7 +21,6 @@ import { sambaNovaModels, vertexModels, vscodeLlmModels, - xaiModels, internationalZAiModels, minimaxModels, } from "./providers/index.js" @@ -50,6 +49,7 @@ export const dynamicProviders = [ "glama", "roo", "chutes", + "xai", ] as const export type DynamicProvider = (typeof dynamicProviders)[number] @@ -137,7 +137,6 @@ export const providerNames = [ "roo", "sambanova", "vertex", - "xai", "zai", ] as const @@ -354,6 +353,7 @@ const fakeAiSchema = baseProviderSettingsSchema.extend({ const xaiSchema = apiModelIdProviderModelSchema.extend({ xaiApiKey: z.string().optional(), + xaiModelContextWindow: z.number().int().min(1).optional(), }) const groqSchema = apiModelIdProviderModelSchema.extend({ @@ -709,7 +709,7 @@ export const MODELS_BY_PROVIDER: Record< label: "VS Code LM API", models: Object.keys(vscodeLlmModels), }, - xai: { id: "xai", label: "xAI (Grok)", models: Object.keys(xaiModels) }, + xai: { id: "xai", label: "xAI (Grok)", models: [] }, zai: { id: "zai", label: "Zai", models: Object.keys(internationalZAiModels) }, // Dynamic providers; models pulled from remote APIs. diff --git a/packages/types/src/providers/xai.ts b/packages/types/src/providers/xai.ts index 3189e593da30..e241be77610e 100644 --- a/packages/types/src/providers/xai.ts +++ b/packages/types/src/providers/xai.ts @@ -3,93 +3,63 @@ import type { ModelInfo } from "../model.js" // https://docs.x.ai/docs/api-reference export type XAIModelId = keyof typeof xaiModels -export const xaiDefaultModelId: XAIModelId = "grok-code-fast-1" +export const xaiDefaultModelId: XAIModelId = "grok-4-fast-reasoning" + +/** + * Partial ModelInfo for xAI static registry. + * Contains only fields not available from the xAI API: + * - contextWindow: Not provided by API + * - maxTokens: Not provided by API + * - description: User-friendly descriptions + * - supportsReasoningEffort: Special capability flag + * + * All other fields (pricing, supportsPromptCache, supportsImages) are fetched dynamically. 
+ */ +type XAIStaticModelInfo = Pick & { + maxTokens?: number | null + supportsReasoningEffort?: boolean +} export const xaiModels = { "grok-code-fast-1": { maxTokens: 16_384, - contextWindow: 262_144, - supportsImages: false, - supportsPromptCache: true, - inputPrice: 0.2, - outputPrice: 1.5, - cacheWritesPrice: 0.02, - cacheReadsPrice: 0.02, + contextWindow: 256_000, description: "xAI's Grok Code Fast model with 256K context window", }, - "grok-4": { - maxTokens: 8192, - contextWindow: 256000, - supportsImages: true, - supportsPromptCache: true, - inputPrice: 3.0, - outputPrice: 15.0, - cacheWritesPrice: 0.75, - cacheReadsPrice: 0.75, + "grok-4-0709": { + maxTokens: 16_384, + contextWindow: 256_000, description: "xAI's Grok-4 model with 256K context window", }, + "grok-4-fast-non-reasoning": { + maxTokens: 32_768, + contextWindow: 2_000_000, + description: "xAI's Grok-4 Fast Non-Reasoning model with 2M context window", + }, + "grok-4-fast-reasoning": { + maxTokens: 32_768, + contextWindow: 2_000_000, + description: "xAI's Grok-4 Fast Reasoning model with 2M context window", + }, "grok-3": { maxTokens: 8192, - contextWindow: 131072, - supportsImages: false, - supportsPromptCache: true, - inputPrice: 3.0, - outputPrice: 15.0, - cacheWritesPrice: 0.75, - cacheReadsPrice: 0.75, + contextWindow: 131_072, description: "xAI's Grok-3 model with 128K context window", }, - "grok-3-fast": { - maxTokens: 8192, - contextWindow: 131072, - supportsImages: false, - supportsPromptCache: true, - inputPrice: 5.0, - outputPrice: 25.0, - cacheWritesPrice: 1.25, - cacheReadsPrice: 1.25, - description: "xAI's Grok-3 fast model with 128K context window", - }, "grok-3-mini": { maxTokens: 8192, - contextWindow: 131072, - supportsImages: false, - supportsPromptCache: true, - inputPrice: 0.3, - outputPrice: 0.5, - cacheWritesPrice: 0.07, - cacheReadsPrice: 0.07, + contextWindow: 131_072, description: "xAI's Grok-3 mini model with 128K context window", supportsReasoningEffort: true, }, - "grok-3-mini-fast": { - maxTokens: 8192, - contextWindow: 131072, - supportsImages: false, - supportsPromptCache: true, - inputPrice: 0.6, - outputPrice: 4.0, - cacheWritesPrice: 0.15, - cacheReadsPrice: 0.15, - description: "xAI's Grok-3 mini fast model with 128K context window", - supportsReasoningEffort: true, - }, "grok-2-1212": { maxTokens: 8192, - contextWindow: 131072, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 2.0, - outputPrice: 10.0, - description: "xAI's Grok-2 model (version 1212) with 128K context window", + contextWindow: 32_768, + description: "xAI's Grok-2 model (version 1212) with 32K context window", }, "grok-2-vision-1212": { maxTokens: 8192, - contextWindow: 32768, - supportsImages: true, - supportsPromptCache: false, - inputPrice: 2.0, - outputPrice: 10.0, + contextWindow: 32_768, description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window", }, -} as const satisfies Record +} as const satisfies Record diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index 1d3d4a150931..ea9bc92c7dd2 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -57,7 +57,12 @@ describe("XAIHandler", () => { it("should return default model when no model is specified", () => { const model = handler.getModel() expect(model.id).toBe(xaiDefaultModelId) - expect(model.info).toEqual(xaiModels[xaiDefaultModelId]) + expect(model.info).toMatchObject({ + contextWindow: 
xaiModels[xaiDefaultModelId].contextWindow, + maxTokens: xaiModels[xaiDefaultModelId].maxTokens, + description: xaiModels[xaiDefaultModelId].description, + }) + expect(model.info.supportsPromptCache).toBe(false) // Placeholder until dynamic data loads }) test("should return specified model when valid model is provided", () => { @@ -66,7 +71,12 @@ describe("XAIHandler", () => { const model = handlerWithModel.getModel() expect(model.id).toBe(testModelId) - expect(model.info).toEqual(xaiModels[testModelId]) + expect(model.info).toMatchObject({ + contextWindow: xaiModels[testModelId].contextWindow, + maxTokens: xaiModels[testModelId].maxTokens, + description: xaiModels[testModelId].description, + }) + expect(model.info.supportsPromptCache).toBe(false) // Placeholder until dynamic data loads }) it("should include reasoning_effort parameter for mini models", async () => { @@ -234,12 +244,13 @@ describe("XAIHandler", () => { // Verify the usage data expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ + expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20, cacheReadTokens: 5, cacheWriteTokens: 15, + totalCost: expect.any(Number), }) }) diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 0e767ce23797..f0ddfbeef2cb 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -230,7 +230,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } if (inputTokens > 0 || outputTokens > 0 || cacheWriteTokens > 0 || cacheReadTokens > 0) { - const { totalCost } = calculateApiCostAnthropic( + const totalCost = calculateApiCostAnthropic( this.getModel().info, inputTokens, outputTokens, diff --git a/src/api/providers/cerebras.ts b/src/api/providers/cerebras.ts index 16dfa282adb1..44c405eeb4ad 100644 --- a/src/api/providers/cerebras.ts +++ b/src/api/providers/cerebras.ts @@ -331,7 +331,7 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan const { info } = this.getModel() // Use actual token usage from the last request const { inputTokens, outputTokens } = this.lastUsage - const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens) + const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens) return totalCost } } diff --git a/src/api/providers/deepinfra.ts b/src/api/providers/deepinfra.ts index fb8c117ae013..7cf018b069f5 100644 --- a/src/api/providers/deepinfra.ts +++ b/src/api/providers/deepinfra.ts @@ -131,9 +131,9 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0 const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0 - const { totalCost } = modelInfo + const totalCost = modelInfo ? 
calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) - : { totalCost: 0 } + : 0 return { type: "usage", diff --git a/src/api/providers/fetchers/__tests__/xai.spec.ts b/src/api/providers/fetchers/__tests__/xai.spec.ts new file mode 100644 index 000000000000..82ed51352805 --- /dev/null +++ b/src/api/providers/fetchers/__tests__/xai.spec.ts @@ -0,0 +1,56 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import axios from "axios" + +vi.mock("axios") + +import { getXaiModels } from "../xai" +import { xaiModels } from "@roo-code/types" + +describe("getXaiModels", () => { + const mockedAxios = axios as unknown as { get: ReturnType } + + beforeEach(() => { + vi.clearAllMocks() + }) + + it("returns mapped models with pricing and modalities (augmenting static info when available)", async () => { + mockedAxios.get = vi.fn().mockResolvedValue({ + data: { + models: [ + { + id: "grok-3", + input_modalities: ["text"], + output_modalities: ["text"], + prompt_text_token_price: 2000, // 2000 fractional cents = $0.20 per 1M tokens + cached_prompt_text_token_price: 500, // 500 fractional cents = $0.05 per 1M tokens + completion_text_token_price: 10000, // 10000 fractional cents = $1.00 per 1M tokens + aliases: ["grok-3-latest"], + }, + ], + }, + }) + + const result = await getXaiModels("key", "https://api.x.ai/v1") + expect(result["grok-3"]).toBeDefined() + expect(result["grok-3"]?.supportsImages).toBe(false) + expect(result["grok-3"]?.inputPrice).toBeCloseTo(0.2) // $0.20 per 1M tokens + expect(result["grok-3"]?.outputPrice).toBeCloseTo(1.0) // $1.00 per 1M tokens + expect(result["grok-3"]?.cacheReadsPrice).toBeCloseTo(0.05) // $0.05 per 1M tokens + // aliases are not added to avoid UI duplication + expect(result["grok-3-latest"]).toBeUndefined() + }) + + it("returns empty object on schema mismatches (graceful degradation)", async () => { + mockedAxios.get = vi.fn().mockResolvedValue({ + data: { data: [{ bogus: true }] }, + }) + const result = await getXaiModels("key") + expect(result).toEqual({}) + }) + + it("includes Authorization header when apiKey provided", async () => { + mockedAxios.get = vi.fn().mockResolvedValue({ data: { data: [] } }) + await getXaiModels("secret") + expect((axios.get as any).mock.calls[0][1].headers.Authorization).toBe("Bearer secret") + }) +}) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 722e66dd7286..860178bf49fa 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -26,6 +26,7 @@ import { getDeepInfraModels } from "./deepinfra" import { getHuggingFaceModels } from "./huggingface" import { getRooModels } from "./roo" import { getChutesModels } from "./chutes" +import { getXaiModels } from "./xai" const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) @@ -101,6 +102,9 @@ export const getModels = async (options: GetModelsOptions): Promise case "huggingface": models = await getHuggingFaceModels() break + case "xai": + models = await getXaiModels(options.apiKey, options.baseUrl) + break case "roo": { // Roo Code Cloud provider requires baseUrl and optional apiKey const rooBaseUrl = @@ -121,7 +125,7 @@ export const getModels = async (options: GetModelsOptions): Promise // Cache the fetched models (even if empty, to signify a successful fetch with no models). 
memoryCache.set(provider, models) - await writeModels(provider, models).catch((err) => + await writeModels(provider, models || {}).catch((err) => console.error(`[getModels] Error writing ${provider} models to file cache:`, err), ) diff --git a/src/api/providers/fetchers/xai.ts b/src/api/providers/fetchers/xai.ts new file mode 100644 index 000000000000..4ca861864d00 --- /dev/null +++ b/src/api/providers/fetchers/xai.ts @@ -0,0 +1,107 @@ +import axios from "axios" +import { z } from "zod" + +import { type ModelInfo, xaiModels } from "@roo-code/types" +import { DEFAULT_HEADERS } from "../../providers/constants" + +/** + * Schema for GET https://api.x.ai/v1/language-models + * This endpoint returns rich metadata including modalities and pricing. + */ +const xaiLanguageModelSchema = z.object({ + id: z.string(), + input_modalities: z.array(z.string()).optional(), + output_modalities: z.array(z.string()).optional(), + prompt_text_token_price: z.number().optional(), // fractional cents (basis points) per 1M tokens + cached_prompt_text_token_price: z.number().optional(), // fractional cents per 1M tokens + prompt_image_token_price: z.number().optional(), // fractional cents per 1M tokens + completion_text_token_price: z.number().optional(), // fractional cents per 1M tokens + search_price: z.number().optional(), + aliases: z.array(z.string()).optional(), +}) + +const xaiLanguageModelsResponseSchema = z.object({ + models: z.array(xaiLanguageModelSchema), +}) + +/** + * Fetch available xAI models for the authenticated account. + * - Uses Bearer Authorization header when apiKey is provided + * - Maps discovered IDs to ModelInfo using static catalog (xaiModels) when possible + * - For models not in static catalog, contextWindow and maxTokens remain undefined + */ +export async function getXaiModels(apiKey?: string, baseUrl?: string): Promise> { + const models: Record = {} + // Build proper endpoint whether user passes https://api.x.ai or https://api.x.ai/v1 + const base = baseUrl ? baseUrl.replace(/\/+$/, "") : "https://api.x.ai" + const url = base.endsWith("/v1") ? `${base}/language-models` : `${base}/v1/language-models` + + try { + const resp = await axios.get(url, { + headers: { + ...DEFAULT_HEADERS, + Accept: "application/json", + ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}), + }, + }) + + const parsed = xaiLanguageModelsResponseSchema.safeParse(resp.data) + const items = parsed.success + ? parsed.data.models + : Array.isArray((resp.data as any)?.models) + ? (resp.data as any)?.models + : [] + + if (!parsed.success) { + console.error("xAI language models response validation failed", parsed.error?.format?.() ?? parsed.error) + } + + // Helper to convert fractional-cents-per-1M (basis points) to dollars-per-1M + // The API returns values in 1/100th of a cent, so divide by 10,000 to get dollars + const centsToDollars = (v?: number) => (typeof v === "number" ? v / 10_000 : undefined) + + for (const m of items) { + const id = m.id + const staticInfo = xaiModels[id as keyof typeof xaiModels] + const supportsImages = Array.isArray(m.input_modalities) ? m.input_modalities.includes("image") : false + + // Cache support is indicated by presence of cached_prompt_text_token_price field (even if 0) + const supportsPromptCache = typeof m.cached_prompt_text_token_price === "number" + const cacheReadsPrice = supportsPromptCache ? centsToDollars(m.cached_prompt_text_token_price) : undefined + + const info: ModelInfo = { + maxTokens: staticInfo?.maxTokens ?? 
undefined, + contextWindow: staticInfo?.contextWindow ?? undefined, + supportsImages, + supportsPromptCache, + inputPrice: centsToDollars(m.prompt_text_token_price), + outputPrice: centsToDollars(m.completion_text_token_price), + cacheReadsPrice, + cacheWritesPrice: undefined, // Leave undefined unless API exposes a distinct write price + description: staticInfo?.description, + supportsReasoningEffort: + staticInfo && "supportsReasoningEffort" in staticInfo + ? staticInfo.supportsReasoningEffort + : undefined, + // leave other optional fields undefined unless available via static definitions + } + + models[id] = info + // Aliases are not added to the model list to avoid duplication in UI + // Users should use the primary model ID; xAI API will handle alias resolution + } + } catch (error) { + // Avoid logging sensitive data like Authorization headers + if (axios.isAxiosError(error)) { + const status = error.response?.status + const statusText = error.response?.statusText + const url = (error as any)?.config?.url + console.error(`[xAI] models fetch failed: ${status ?? "unknown"} ${statusText ?? ""} ${url ?? ""}`.trim()) + } else { + console.error("[xAI] models fetch failed.", error instanceof Error ? error.message : String(error)) + } + throw error + } + + return models +} diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts index c2f2dd19db98..b66e42d7f016 100644 --- a/src/api/providers/groq.ts +++ b/src/api/providers/groq.ts @@ -64,7 +64,7 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider { const cacheWriteTokens = 0 // Calculate cost using OpenAI-compatible cost calculation - const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) + const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) yield { type: "usage", diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts index 43bf33c38be0..0c6b0dd73eb4 100644 --- a/src/api/providers/lite-llm.ts +++ b/src/api/providers/lite-llm.ts @@ -165,7 +165,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa (lastUsage as any).prompt_cache_hit_tokens || 0 - const { totalCost } = calculateApiCostOpenAI( + const totalCost = calculateApiCostOpenAI( info, lastUsage.prompt_tokens || 0, lastUsage.completion_tokens || 0, diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index daf6278822b5..42173979cc88 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -100,7 +100,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio // Pass total input tokens directly to calculateApiCostOpenAI // The function handles subtracting both cache reads and writes internally - const { totalCost } = calculateApiCostOpenAI( + const totalCost = calculateApiCostOpenAI( effectiveInfo, totalInputTokens, totalOutputTokens, diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index 1c0e9ed64075..16aefae52861 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -85,9 +85,9 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan const outputTokens = requestyUsage?.completion_tokens || 0 const cacheWriteTokens = requestyUsage?.prompt_tokens_details?.caching_tokens || 0 const cacheReadTokens = requestyUsage?.prompt_tokens_details?.cached_tokens || 0 - const { totalCost } = modelInfo + const totalCost = modelInfo ? 
calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) - : { totalCost: 0 } + : 0 return { type: "usage", diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index 7eb6e9866dd8..e019a44b3dd5 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -1,7 +1,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { type XAIModelId, xaiDefaultModelId, xaiModels } from "@roo-code/types" +import { type XAIModelId, xaiDefaultModelId, xaiModels, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" @@ -13,6 +13,9 @@ import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { handleOpenAIError } from "./utils/openai-error-handler" +import { calculateApiCostOpenAI } from "../../shared/cost" +import type { ModelRecord } from "../../shared/api" +import { getModels } from "./fetchers/modelCache" const XAI_DEFAULT_TEMPERATURE = 0 @@ -20,6 +23,7 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler protected options: ApiHandlerOptions private client: OpenAI private readonly providerName = "xAI" + protected models: ModelRecord = {} constructor(options: ApiHandlerOptions) { super() @@ -35,21 +39,49 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler } override getModel() { - const id = - this.options.apiModelId && this.options.apiModelId in xaiModels - ? (this.options.apiModelId as XAIModelId) - : xaiDefaultModelId + // Allow any model ID (dynamic discovery) and augment with static info when available + const id = this.options.apiModelId ?? xaiDefaultModelId + + const staticInfo = (xaiModels as Record)[id as any] + const dynamicInfo = this.models?.[id as any] + + // Build complete ModelInfo using dynamic pricing/capabilities when available + const info: ModelInfo = { + contextWindow: this.options.xaiModelContextWindow ?? staticInfo?.contextWindow, + maxTokens: staticInfo?.maxTokens ?? undefined, + supportsPromptCache: dynamicInfo?.supportsPromptCache ?? false, + supportsImages: dynamicInfo?.supportsImages, + inputPrice: dynamicInfo?.inputPrice, + outputPrice: dynamicInfo?.outputPrice, + cacheReadsPrice: dynamicInfo?.cacheReadsPrice, + cacheWritesPrice: dynamicInfo?.cacheWritesPrice, + description: staticInfo?.description, + supportsReasoningEffort: + staticInfo && "supportsReasoningEffort" in staticInfo ? staticInfo.supportsReasoningEffort : undefined, + } - const info = xaiModels[id] const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options }) return { id, info, ...params } } + private async loadDynamicModels(): Promise { + try { + this.models = await getModels({ + provider: "xai", + apiKey: this.options.xaiApiKey, + baseUrl: (this.client as any).baseURL || "https://api.x.ai/v1", + }) + } catch (error) { + console.error("[XAI] Error loading dynamic models:", error) + } + } + override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { + await this.loadDynamicModels() const { id: modelId, info: modelInfo, reasoning } = this.getModel() // Use the OpenAI-compatible API. 
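
The pricing fields returned by the xAI language-models endpoint are reported in fractional cents (1/100 of a cent) per million tokens, which is why the fetcher above divides by 10,000. A small worked example using the values from the fetcher tests (a sketch, not the shipped helper):

```typescript
// Worked example of the fractional-cents conversion used by getXaiModels (illustrative sketch).
// The API reports prices in 1/100 of a cent per 1M tokens, so dollars-per-1M = value / 10_000.
const centsToDollars = (v?: number): number | undefined =>
	typeof v === "number" ? v / 10_000 : undefined

console.log(centsToDollars(2000)) // 0.2  -> $0.20 per 1M input tokens (grok-3 in the test fixture)
console.log(centsToDollars(500)) // 0.05 -> $0.05 per 1M cached input tokens
console.log(centsToDollars(10_000)) // 1    -> $1.00 per 1M output tokens
```
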
@@ -57,7 +89,7 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler try { stream = await this.client.chat.completions.create({ model: modelId, - max_tokens: modelInfo.maxTokens, + ...(typeof modelInfo.maxTokens === "number" ? { max_tokens: modelInfo.maxTokens } : {}), temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE, messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)], stream: true, @@ -98,12 +130,21 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler const writeTokens = "cache_creation_input_tokens" in chunk.usage ? (chunk.usage as any).cache_creation_input_tokens : 0 + const totalCost = calculateApiCostOpenAI( + modelInfo, + chunk.usage.prompt_tokens || 0, + chunk.usage.completion_tokens || 0, + writeTokens || 0, + readTokens || 0, + ) + yield { type: "usage", inputTokens: chunk.usage.prompt_tokens || 0, outputTokens: chunk.usage.completion_tokens || 0, cacheReadTokens: readTokens, cacheWriteTokens: writeTokens, + totalCost, } } } diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index f4f7058d6f0d..1045d4a242a2 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -2003,11 +2003,11 @@ export class Task extends EventEmitter implements TaskLike { this.clineMessages[lastApiReqIndex].text = JSON.stringify({ ...existingData, - tokensIn: costResult.totalInputTokens, - tokensOut: costResult.totalOutputTokens, + tokensIn: inputTokens, + tokensOut: outputTokens, cacheWrites: cacheWriteTokens, cacheReads: cacheReadTokens, - cost: totalCost ?? costResult.totalCost, + cost: totalCost ?? costResult, cancelReason, streamingFailedMessage, } satisfies ClineApiReqInfo) @@ -2234,11 +2234,11 @@ export class Task extends EventEmitter implements TaskLike { ) TelemetryService.instance.captureLlmCompletion(this.taskId, { - inputTokens: costResult.totalInputTokens, - outputTokens: costResult.totalOutputTokens, + inputTokens: tokens.input, + outputTokens: tokens.output, cacheWriteTokens: tokens.cacheWrite, cacheReadTokens: tokens.cacheRead, - cost: tokens.total ?? costResult.totalCost, + cost: tokens.total ?? 
costResult, }) } } diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index a8ab39108d95..0acb1fe0c494 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -2720,6 +2720,7 @@ describe("ClineProvider - Router Models", () => { "vercel-ai-gateway": mockModels, huggingface: {}, "io-intelligence": {}, + xai: {}, }, values: undefined, }) @@ -2776,6 +2777,7 @@ describe("ClineProvider - Router Models", () => { "vercel-ai-gateway": mockModels, huggingface: {}, "io-intelligence": {}, + xai: {}, }, values: undefined, }) @@ -2900,6 +2902,7 @@ describe("ClineProvider - Router Models", () => { "vercel-ai-gateway": mockModels, huggingface: {}, "io-intelligence": {}, + xai: {}, }, values: undefined, }) diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 3fd2a47f3778..fd7f11cec621 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -255,6 +255,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { "vercel-ai-gateway": mockModels, huggingface: {}, "io-intelligence": {}, + xai: {}, }, values: undefined, }) @@ -349,6 +350,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { "vercel-ai-gateway": mockModels, huggingface: {}, "io-intelligence": {}, + xai: {}, }, values: undefined, }) @@ -380,7 +382,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { type: "requestRouterModels", }) - // Verify error messages were sent for failed providers (these come first) + // Verify error messages were sent for failed providers expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ type: "singleRouterModelFetchResponse", success: false, @@ -426,6 +428,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { "vercel-ai-gateway": mockModels, huggingface: {}, "io-intelligence": {}, + xai: {}, }, values: undefined, }) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 0a4bd9abb90c..d982861c3fc9 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -777,6 +777,7 @@ export const webviewMessageHandler = async ( lmstudio: {}, roo: {}, chutes: {}, + xai: {}, } const safeGetModels = async (options: GetModelsOptions): Promise => { @@ -838,6 +839,14 @@ export const webviewMessageHandler = async ( }) } + // Add xAI if API key is provided. 
+ if (apiConfiguration.xaiApiKey) { + candidates.push({ + key: "xai", + options: { provider: "xai", apiKey: apiConfiguration.xaiApiKey, baseUrl: "https://api.x.ai/v1" }, + }) + } + // LiteLLM is conditional on baseUrl+apiKey const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl diff --git a/src/shared/__tests__/cost.spec.ts b/src/shared/__tests__/cost.spec.ts new file mode 100644 index 000000000000..563cfc29fb4b --- /dev/null +++ b/src/shared/__tests__/cost.spec.ts @@ -0,0 +1,114 @@ +import { describe, expect, it } from "vitest" +import { parseApiPrice, calculateApiCostAnthropic, calculateApiCostOpenAI } from "../cost" +import type { ModelInfo } from "@roo-code/types" + +describe("parseApiPrice", () => { + it("should handle zero as a number", () => { + expect(parseApiPrice(0)).toBe(0) + }) + + it("should handle zero as a string", () => { + expect(parseApiPrice("0")).toBe(0) + }) + + it("should handle positive numbers", () => { + expect(parseApiPrice(0.0002)).toBe(200) + expect(parseApiPrice(0.00002)).toBe(20) + }) + + it("should handle positive number strings", () => { + expect(parseApiPrice("0.0002")).toBe(200) + expect(parseApiPrice("0.00002")).toBe(20) + }) + + it("should return undefined for null", () => { + expect(parseApiPrice(null)).toBeUndefined() + }) + + it("should return undefined for undefined", () => { + expect(parseApiPrice(undefined)).toBeUndefined() + }) + + it("should return undefined for empty string", () => { + expect(parseApiPrice("")).toBeUndefined() + }) +}) + +describe("calculateApiCostAnthropic", () => { + const modelInfo: ModelInfo = { + maxTokens: 4096, + contextWindow: 200000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 300, + outputPrice: 1500, + cacheWritesPrice: 375, + cacheReadsPrice: 30, + } + + it("should calculate cost without caching", () => { + const cost = calculateApiCostAnthropic(modelInfo, 1000, 500) + expect(cost).toBeCloseTo(0.3 + 0.75, 10) + }) + + it("should calculate cost with cache creation", () => { + const cost = calculateApiCostAnthropic(modelInfo, 1000, 500, 2000) + expect(cost).toBeCloseTo(0.3 + 0.75 + 0.75, 10) + }) + + it("should calculate cost with cache reads", () => { + const cost = calculateApiCostAnthropic(modelInfo, 1000, 500, 0, 3000) + expect(cost).toBeCloseTo(0.3 + 0.75 + 0.09, 10) + }) + + it("should handle zero cost for free models", () => { + const freeModel: ModelInfo = { + maxTokens: 4096, + contextWindow: 200000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + } + const cost = calculateApiCostAnthropic(freeModel, 1000, 500) + expect(cost).toBe(0) + }) +}) + +describe("calculateApiCostOpenAI", () => { + const modelInfo: ModelInfo = { + maxTokens: 4096, + contextWindow: 128000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 150, + outputPrice: 600, + cacheWritesPrice: 187.5, + cacheReadsPrice: 15, + } + + it("should calculate cost without caching", () => { + const cost = calculateApiCostOpenAI(modelInfo, 1000, 500) + expect(cost).toBeCloseTo(0.15 + 0.3, 10) + }) + + it("should subtract cached tokens from input tokens", () => { + const cost = calculateApiCostOpenAI(modelInfo, 5000, 500, 2000, 1000) + // 5000 total - 2000 cache creation - 1000 cache read = 2000 non-cached + // Cost: (2000 * 0.00015) + (2000 * 0.0001875) + (1000 * 0.000015) + (500 * 0.0006) + expect(cost).toBeCloseTo(0.3 + 0.375 + 0.015 + 0.3, 10) + }) + 
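
The expected value in the test above follows directly from the refactored `calculateApiCostOpenAI`, which now returns a plain number and subtracts cached tokens from the input count. A standalone sketch of the same arithmetic, with prices in dollars per 1M tokens (not the shipped code):

```typescript
// Sketch of the OpenAI-style cost math exercised by the test above.
const inputPrice = 150
const outputPrice = 600
const cacheWritesPrice = 187.5
const cacheReadsPrice = 15

const inputTokens = 5000
const outputTokens = 500
const cacheWrites = 2000
const cacheReads = 1000

// Cached tokens are removed from the billable input count.
const nonCached = Math.max(0, inputTokens - cacheWrites - cacheReads) // 2000

const totalCost =
	(cacheWritesPrice / 1_000_000) * cacheWrites + // 0.375
	(cacheReadsPrice / 1_000_000) * cacheReads + // 0.015
	(inputPrice / 1_000_000) * nonCached + // 0.30
	(outputPrice / 1_000_000) * outputTokens // 0.30

console.log(totalCost) // 0.99, matching the expectation in the test above
```
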
+ it("should handle zero cost for free models", () => { + const freeModel: ModelInfo = { + maxTokens: 4096, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + } + const cost = calculateApiCostOpenAI(freeModel, 1000, 500) + expect(cost).toBe(0) + }) +}) diff --git a/src/shared/api.ts b/src/shared/api.ts index 802654adaad9..d788a922db8d 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -165,6 +165,7 @@ const dynamicProviderExtras = { lmstudio: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type roo: {} as { apiKey?: string; baseUrl?: string }, chutes: {} as { apiKey?: string }, + xai: {} as { apiKey?: string; baseUrl?: string }, } as const satisfies Record // Build the dynamic options union from the map, intersected with CommonFetchParams diff --git a/src/shared/cost.ts b/src/shared/cost.ts index fea686d8aed8..fcff3aeaf006 100644 --- a/src/shared/cost.ts +++ b/src/shared/cost.ts @@ -12,20 +12,15 @@ function calculateApiCostInternal( outputTokens: number, cacheCreationInputTokens: number, cacheReadInputTokens: number, - totalInputTokens: number, - totalOutputTokens: number, -): ApiCostResult { + totalInputTokens: number, // kept for potential future use + totalOutputTokens: number, // kept for potential future use +): number { const cacheWritesCost = ((modelInfo.cacheWritesPrice || 0) / 1_000_000) * cacheCreationInputTokens const cacheReadsCost = ((modelInfo.cacheReadsPrice || 0) / 1_000_000) * cacheReadInputTokens const baseInputCost = ((modelInfo.inputPrice || 0) / 1_000_000) * inputTokens const outputCost = ((modelInfo.outputPrice || 0) / 1_000_000) * outputTokens const totalCost = cacheWritesCost + cacheReadsCost + baseInputCost + outputCost - - return { - totalInputTokens, - totalOutputTokens, - totalCost, - } + return totalCost } // For Anthropic compliant usage, the input tokens count does NOT include the @@ -36,7 +31,7 @@ export function calculateApiCostAnthropic( outputTokens: number, cacheCreationInputTokens?: number, cacheReadInputTokens?: number, -): ApiCostResult { +): number { const cacheCreation = cacheCreationInputTokens || 0 const cacheRead = cacheReadInputTokens || 0 @@ -62,7 +57,7 @@ export function calculateApiCostOpenAI( outputTokens: number, cacheCreationInputTokens?: number, cacheReadInputTokens?: number, -): ApiCostResult { +): number { const cacheCreationInputTokensNum = cacheCreationInputTokens || 0 const cacheReadInputTokensNum = cacheReadInputTokens || 0 const nonCachedInputTokens = Math.max(0, inputTokens - cacheCreationInputTokensNum - cacheReadInputTokensNum) @@ -80,4 +75,8 @@ export function calculateApiCostOpenAI( ) } -export const parseApiPrice = (price: any) => (price ? parseFloat(price) * 1_000_000 : undefined) +export const parseApiPrice = (price: any) => { + if (price == null) return undefined + const parsed = parseFloat(price) + return isNaN(parsed) ? 
undefined : parsed * 1_000_000 +} diff --git a/src/utils/__tests__/cost.spec.ts b/src/utils/__tests__/cost.spec.ts index 83d268713697..1bf585c4054f 100644 --- a/src/utils/__tests__/cost.spec.ts +++ b/src/utils/__tests__/cost.spec.ts @@ -22,9 +22,7 @@ describe("Cost Utility", () => { // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Total: 0.003 + 0.0075 = 0.0105 - expect(result.totalCost).toBe(0.0105) - expect(result.totalInputTokens).toBe(1000) - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0105) }) it("should handle cache writes cost", () => { @@ -34,9 +32,7 @@ describe("Cost Utility", () => { // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 // Total: 0.003 + 0.0075 + 0.0075 = 0.018 - expect(result.totalCost).toBeCloseTo(0.018, 6) - expect(result.totalInputTokens).toBe(3000) // 1000 + 2000 - expect(result.totalOutputTokens).toBe(500) + expect(result).toBeCloseTo(0.018, 6) }) it("should handle cache reads cost", () => { @@ -46,9 +42,7 @@ describe("Cost Utility", () => { // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 // Total: 0.003 + 0.0075 + 0.0009 = 0.0114 - expect(result.totalCost).toBe(0.0114) - expect(result.totalInputTokens).toBe(4000) // 1000 + 3000 - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0114) }) it("should handle all cost components together", () => { @@ -59,9 +53,7 @@ describe("Cost Utility", () => { // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 // Total: 0.003 + 0.0075 + 0.0075 + 0.0009 = 0.0189 - expect(result.totalCost).toBe(0.0189) - expect(result.totalInputTokens).toBe(6000) // 1000 + 2000 + 3000 - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0189) }) it("should handle missing prices gracefully", () => { @@ -72,16 +64,12 @@ describe("Cost Utility", () => { } const result = calculateApiCostAnthropic(modelWithoutPrices, 1000, 500, 2000, 3000) - expect(result.totalCost).toBe(0) - expect(result.totalInputTokens).toBe(6000) // 1000 + 2000 + 3000 - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0) }) it("should handle zero tokens", () => { const result = calculateApiCostAnthropic(mockModelInfo, 0, 0, 0, 0) - expect(result.totalCost).toBe(0) - expect(result.totalInputTokens).toBe(0) - expect(result.totalOutputTokens).toBe(0) + expect(result).toBe(0) }) it("should handle undefined cache values", () => { @@ -90,9 +78,7 @@ describe("Cost Utility", () => { // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Total: 0.003 + 0.0075 = 0.0105 - expect(result.totalCost).toBe(0.0105) - expect(result.totalInputTokens).toBe(1000) - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0105) }) it("should handle missing cache prices", () => { @@ -108,9 +94,7 @@ describe("Cost Utility", () => { // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Total: 0.003 + 0.0075 = 0.0105 - expect(result.totalCost).toBe(0.0105) - expect(result.totalInputTokens).toBe(6000) // 1000 + 2000 + 3000 - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0105) }) }) @@ -131,9 +115,7 @@ describe("Cost Utility", () => { // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Total: 0.003 + 0.0075 = 0.0105 - expect(result.totalCost).toBe(0.0105) - 
expect(result.totalInputTokens).toBe(1000) - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0105) }) it("should handle cache writes cost", () => { @@ -143,9 +125,7 @@ describe("Cost Utility", () => { // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 // Total: 0.003 + 0.0075 + 0.0075 = 0.018 - expect(result.totalCost).toBeCloseTo(0.018, 6) - expect(result.totalInputTokens).toBe(3000) // Total already includes cache - expect(result.totalOutputTokens).toBe(500) + expect(result).toBeCloseTo(0.018, 6) }) it("should handle cache reads cost", () => { @@ -155,9 +135,7 @@ describe("Cost Utility", () => { // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 // Total: 0.003 + 0.0075 + 0.0009 = 0.0114 - expect(result.totalCost).toBe(0.0114) - expect(result.totalInputTokens).toBe(4000) // Total already includes cache - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0114) }) it("should handle all cost components together", () => { @@ -168,9 +146,7 @@ describe("Cost Utility", () => { // Cache writes: (3.75 / 1_000_000) * 2000 = 0.0075 // Cache reads: (0.3 / 1_000_000) * 3000 = 0.0009 // Total: 0.003 + 0.0075 + 0.0075 + 0.0009 = 0.0189 - expect(result.totalCost).toBe(0.0189) - expect(result.totalInputTokens).toBe(6000) // Total already includes cache - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0189) }) it("should handle missing prices gracefully", () => { @@ -181,16 +157,12 @@ describe("Cost Utility", () => { } const result = calculateApiCostOpenAI(modelWithoutPrices, 1000, 500, 2000, 3000) - expect(result.totalCost).toBe(0) - expect(result.totalInputTokens).toBe(1000) // Total already includes cache - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0) }) it("should handle zero tokens", () => { const result = calculateApiCostOpenAI(mockModelInfo, 0, 0, 0, 0) - expect(result.totalCost).toBe(0) - expect(result.totalInputTokens).toBe(0) - expect(result.totalOutputTokens).toBe(0) + expect(result).toBe(0) }) it("should handle undefined cache values", () => { @@ -199,9 +171,7 @@ describe("Cost Utility", () => { // Input cost: (3.0 / 1_000_000) * 1000 = 0.003 // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Total: 0.003 + 0.0075 = 0.0105 - expect(result.totalCost).toBe(0.0105) - expect(result.totalInputTokens).toBe(1000) - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0105) }) it("should handle missing cache prices", () => { @@ -217,9 +187,7 @@ describe("Cost Utility", () => { // Input cost: (3.0 / 1_000_000) * (6000 - 2000 - 3000) = 0.003 // Output cost: (15.0 / 1_000_000) * 500 = 0.0075 // Total: 0.003 + 0.0075 = 0.0105 - expect(result.totalCost).toBe(0.0105) - expect(result.totalInputTokens).toBe(6000) // Total already includes cache - expect(result.totalOutputTokens).toBe(500) + expect(result).toBe(0.0105) }) }) }) diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index e2e7ba561573..4c5a343bf3bb 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -235,7 +235,8 @@ const ApiOptions = ({ } else if ( selectedProvider === "litellm" || selectedProvider === "deepinfra" || - selectedProvider === "roo" + selectedProvider === "roo" || + selectedProvider === "xai" ) { vscode.postMessage({ type: "requestRouterModels" }) } @@ -252,6 +253,7 @@ const ApiOptions = ({ 
apiConfiguration?.litellmApiKey, apiConfiguration?.deepInfraApiKey, apiConfiguration?.deepInfraBaseUrl, + apiConfiguration?.xaiApiKey, customHeaders, ], ) @@ -609,7 +611,14 @@ const ApiOptions = ({ )} {selectedProvider === "xai" && ( - + )} {selectedProvider === "groq" && ( @@ -700,7 +709,7 @@ const ApiOptions = ({ )} - {selectedProviderModels.length > 0 && ( + {selectedProvider !== "xai" && selectedProviderModels.length > 0 && ( <>
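
With xAI now routed through the dynamic router-models path, the settings UI relies on a simple request/response handshake with the extension host. A minimal sketch of that flow, assuming the webview's usual `vscode` wrapper (the import path and handler wiring here are illustrative; the real component also tracks error and success state):

```typescript
// Illustrative sketch of the webview <-> extension handshake for xAI model discovery.
// Message names ("flushRouterModels", "requestRouterModels", "routerModels") come from this diff.
import { vscode } from "@/utils/vscode" // assumed path for the webview's postMessage wrapper

function refreshXaiModels(): void {
	// Drop the cached xAI entry, then ask the extension to re-fetch all router models.
	vscode.postMessage({ type: "flushRouterModels", text: "xai" })
	vscode.postMessage({ type: "requestRouterModels" })
}

window.addEventListener("message", (event: MessageEvent) => {
	const message: any = event.data
	if (message?.type === "routerModels") {
		// Models discovered for the authenticated key, keyed by model ID.
		const xaiModels = message.routerModels?.xai ?? {}
		console.log(Object.keys(xaiModels))
	}
})
```
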
diff --git a/webview-ui/src/components/settings/__tests__/ApiOptions.spec.tsx b/webview-ui/src/components/settings/__tests__/ApiOptions.spec.tsx index 7b7f9b33e48d..7c8f6b8c49fe 100644 --- a/webview-ui/src/components/settings/__tests__/ApiOptions.spec.tsx +++ b/webview-ui/src/components/settings/__tests__/ApiOptions.spec.tsx @@ -553,6 +553,32 @@ describe("ApiOptions", () => { expect(screen.getByTestId("litellm-refresh-models")).toBeInTheDocument() }) + it("hides generic Model picker when provider is xai", () => { + renderApiOptions({ + apiConfiguration: { + apiProvider: "xai", + }, + }) + // The generic "Model" label should be absent for xai (uses provider-specific picker) + expect(screen.queryByText("Model")).not.toBeInTheDocument() + }) + + it("disables xAI refresh and hides ModelPicker when no API key", () => { + renderApiOptions({ + apiConfiguration: { + apiProvider: "xai", + xaiApiKey: "", + }, + }) + // Generic Model picker should be hidden for xAI + expect(screen.queryByText("Model")).not.toBeInTheDocument() + // If the provider-specific refresh button is present, it should be disabled without a key + const btn = screen.queryByTestId("xai-refresh-models") + if (btn) { + expect(btn).toBeDisabled() + } + }) + it("does not render LiteLLM component when other provider is selected", () => { renderApiOptions({ apiConfiguration: { diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts index a6631dfd66f4..7aa9920257b1 100644 --- a/webview-ui/src/components/settings/constants.ts +++ b/webview-ui/src/components/settings/constants.ts @@ -12,7 +12,6 @@ import { openAiNativeModels, qwenCodeModels, vertexModels, - xaiModels, groqModels, sambaNovaModels, doubaoModels, @@ -35,7 +34,6 @@ export const MODELS_BY_PROVIDER: Partial void + routerModels?: RouterModels + refetchRouterModels?: () => void + organizationAllowList?: OrganizationAllowList + modelValidationError?: string } -export const XAI = ({ apiConfiguration, setApiConfigurationField }: XAIProps) => { +export const XAI = ({ + apiConfiguration, + setApiConfigurationField, + routerModels, + refetchRouterModels, + organizationAllowList, + modelValidationError, +}: XAIProps) => { const { t } = useAppTranslation() + const [didRefetch, setDidRefetch] = useState() + const [refreshError, setRefreshError] = useState() const handleInputChange = useCallback( ( @@ -27,6 +47,56 @@ export const XAI = ({ apiConfiguration, setApiConfigurationField }: XAIProps) => [setApiConfigurationField], ) + const handleRefresh = useCallback(() => { + // Reset status and request fresh models + setDidRefetch(false) + setRefreshError(undefined) + + // Flush xAI cache and request fresh models + vscode.postMessage({ type: "flushRouterModels", text: "xai" }) + vscode.postMessage({ type: "requestRouterModels" }) + + // Allow consumer to refetch react-query if provided + refetchRouterModels?.() + }, [refetchRouterModels]) + + // Listen for router responses to determine success/failure + useEvent( + "message", + useCallback( + (event: MessageEvent) => { + const message: any = event.data + // Error channel: single provider failure + if (message?.type === "singleRouterModelFetchResponse" && message?.values?.provider === "xai") { + if (!message.success) { + setDidRefetch(false) + setRefreshError( + t("settings:providers.refreshModels.error") || + "Failed to fetch xAI models. 
Please verify your API key and try again.", + ) + } + } + + // Success path: routerModels set with non-empty xai models + if (message?.type === "routerModels") { + const models = message.routerModels?.xai ?? {} + if (models && Object.keys(models).length > 0) { + setRefreshError(undefined) + setDidRefetch(true) + } else if (apiConfiguration?.xaiApiKey) { + // With a key provided, an empty set indicates failure/unavailable + setDidRefetch(false) + setRefreshError( + t("settings:providers.refreshModels.error") || + "No xAI models found for this API key. Please verify your API key and try again.", + ) + } + } + }, + [apiConfiguration?.xaiApiKey, t], + ), + ) + return ( <> {t("settings:providers.apiKeyStorageNotice")}
{!apiConfiguration?.xaiApiKey && ( - + {t("settings:providers.getXaiApiKey")} )} + + {/* Refresh button is disabled without API key */} +
+ +
+ + {/* Status messaging */} + {refreshError &&
{refreshError}
} + {!refreshError && didRefetch && ( +
+ {t("settings:providers.refreshModels.success")} +
+ )} + + {/* Hide ModelPicker until an API key is provided */} + {apiConfiguration?.xaiApiKey ? ( + <> + + + {/* Context Window Override - only show for models not in static registry or with undefined contextWindow */} + {(() => { + const selectedModelId = apiConfiguration?.apiModelId || xaiDefaultModelId + const staticModel = xaiModels[selectedModelId as keyof typeof xaiModels] + const hasStaticContextWindow = staticModel?.contextWindow !== undefined + + if (!hasStaticContextWindow) { + return ( + <> + { + const v = (e.target as HTMLInputElement).value.trim() + const n = Number(v) + return Number.isFinite(n) && n > 0 ? Math.floor(n) : undefined + })} + placeholder="e.g., 256000" + className="w-full mt-4"> + + +
+ This model's context window is not known. Please enter it manually. +
+ + ) + } + return null + })()} + + ) : ( +
+ {t("settings:providers.refreshModels.missingConfig")} +
+ )} ) } diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 296b262c3731..1ef1a1f079db 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -182,8 +182,35 @@ function getSelectedModel({ } case "xai": { const id = apiConfiguration.apiModelId ?? xaiDefaultModelId - const info = xaiModels[id as keyof typeof xaiModels] - return info ? { id, info } : { id, info: undefined } + const dynamicInfo = routerModels.xai?.[id] + if (dynamicInfo) { + // If router-provided contextWindow is missing or invalid (<= 0), apply manual override when provided + const overrideCw = apiConfiguration.xaiModelContextWindow + const info = + !(typeof dynamicInfo.contextWindow === "number" && dynamicInfo.contextWindow > 0) && + typeof overrideCw === "number" + ? { ...dynamicInfo, contextWindow: overrideCw } + : dynamicInfo + return { id, info } + } + const staticInfo = xaiModels[id as keyof typeof xaiModels] + // Build a complete ModelInfo fallback to satisfy UI expectations until dynamic models load + const info: ModelInfo = { + ...openAiModelInfoSaneDefaults, + contextWindow: + apiConfiguration.xaiModelContextWindow ?? + staticInfo?.contextWindow ?? + openAiModelInfoSaneDefaults.contextWindow, + maxTokens: staticInfo?.maxTokens ?? openAiModelInfoSaneDefaults.maxTokens, + supportsPromptCache: false, // Placeholder; dynamic API will provide real value + supportsImages: false, // Placeholder; dynamic API will provide real value + description: staticInfo?.description, + supportsReasoningEffort: + staticInfo && "supportsReasoningEffort" in staticInfo + ? staticInfo.supportsReasoningEffort + : undefined, + } + return { id, info } } case "groq": { const id = apiConfiguration.apiModelId ?? groqDefaultModelId diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts index 0bd7a15962b5..fdce40d7ae13 100644 --- a/webview-ui/src/utils/__tests__/validate.test.ts +++ b/webview-ui/src/utils/__tests__/validate.test.ts @@ -45,6 +45,7 @@ describe("Model Validation Functions", () => { huggingface: {}, roo: {}, chutes: {}, + xai: {}, } const allowAllOrganization: OrganizationAllowList = {