Skip to content

Commit d78431b

Browse files
committed
feat: add dynamic model discovery for xAI provider
- Fetch models from xAI API endpoints instead of hard-coding IDs
- Add model caching with automatic refresh capability
- Update xAI settings UI with model picker and refresh button
- Disable reasoning_effort parameter for incompatible models
- Derive prompt cache support from API pricing data
1 parent c232057 commit d78431b

File tree

14 files changed

+432
-80
lines changed

14 files changed

+432
-80
lines changed

packages/types/src/provider-settings.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export const dynamicProviders = [
4949
"unbound",
5050
"glama",
5151
"roo",
52+
"xai",
5253
] as const
5354

5455
export type DynamicProvider = (typeof dynamicProviders)[number]
@@ -136,7 +137,6 @@ export const providerNames = [
136137
"roo",
137138
"sambanova",
138139
"vertex",
139-
"xai",
140140
"zai",
141141
] as const
142142

@@ -346,6 +346,7 @@ const fakeAiSchema = baseProviderSettingsSchema.extend({
346346

347347
const xaiSchema = apiModelIdProviderModelSchema.extend({
348348
xaiApiKey: z.string().optional(),
349+
xaiModelContextWindow: z.number().optional(),
349350
})
350351

351352
const groqSchema = apiModelIdProviderModelSchema.extend({
@@ -698,7 +699,7 @@ export const MODELS_BY_PROVIDER: Record<
698699
label: "VS Code LM API",
699700
models: Object.keys(vscodeLlmModels),
700701
},
701-
xai: { id: "xai", label: "xAI (Grok)", models: Object.keys(xaiModels) },
702+
xai: { id: "xai", label: "xAI (Grok)", models: [] },
702703
zai: { id: "zai", label: "Zai", models: Object.keys(internationalZAiModels) },
703704

704705
// Dynamic providers; models pulled from remote APIs.

packages/types/src/providers/xai.ts

Lines changed: 36 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,93 +3,63 @@ import type { ModelInfo } from "../model.js"
33
// https://docs.x.ai/docs/api-reference
44
export type XAIModelId = keyof typeof xaiModels
55

6-
export const xaiDefaultModelId: XAIModelId = "grok-code-fast-1"
6+
export const xaiDefaultModelId: XAIModelId = "grok-4-fast-reasoning"
7+
8+
/**
9+
* Partial ModelInfo for xAI static registry.
10+
* Contains only fields not available from the xAI API:
11+
* - contextWindow: Not provided by API
12+
* - maxTokens: Not provided by API
13+
* - description: User-friendly descriptions
14+
* - supportsReasoningEffort: Special capability flag
15+
*
16+
* All other fields (pricing, supportsPromptCache, supportsImages) are fetched dynamically.
17+
*/
18+
type XAIStaticModelInfo = Pick<ModelInfo, "contextWindow" | "description"> & {
19+
maxTokens?: number | null
20+
supportsReasoningEffort?: boolean
21+
}
722

823
export const xaiModels = {
924
"grok-code-fast-1": {
1025
maxTokens: 16_384,
11-
contextWindow: 262_144,
12-
supportsImages: false,
13-
supportsPromptCache: true,
14-
inputPrice: 0.2,
15-
outputPrice: 1.5,
16-
cacheWritesPrice: 0.02,
17-
cacheReadsPrice: 0.02,
26+
contextWindow: 256_000,
1827
description: "xAI's Grok Code Fast model with 256K context window",
1928
},
20-
"grok-4": {
21-
maxTokens: 8192,
22-
contextWindow: 256000,
23-
supportsImages: true,
24-
supportsPromptCache: true,
25-
inputPrice: 3.0,
26-
outputPrice: 15.0,
27-
cacheWritesPrice: 0.75,
28-
cacheReadsPrice: 0.75,
29+
"grok-4-0709": {
30+
maxTokens: 16_384,
31+
contextWindow: 256_000,
2932
description: "xAI's Grok-4 model with 256K context window",
3033
},
34+
"grok-4-fast-non-reasoning": {
35+
maxTokens: 32_768,
36+
contextWindow: 2_000_000,
37+
description: "xAI's Grok-4 Fast Non-Reasoning model with 2M context window",
38+
},
39+
"grok-4-fast-reasoning": {
40+
maxTokens: 32_768,
41+
contextWindow: 2_000_000,
42+
description: "xAI's Grok-4 Fast Reasoning model with 2M context window",
43+
},
3144
"grok-3": {
3245
maxTokens: 8192,
33-
contextWindow: 131072,
34-
supportsImages: false,
35-
supportsPromptCache: true,
36-
inputPrice: 3.0,
37-
outputPrice: 15.0,
38-
cacheWritesPrice: 0.75,
39-
cacheReadsPrice: 0.75,
46+
contextWindow: 131_072,
4047
description: "xAI's Grok-3 model with 128K context window",
4148
},
42-
"grok-3-fast": {
43-
maxTokens: 8192,
44-
contextWindow: 131072,
45-
supportsImages: false,
46-
supportsPromptCache: true,
47-
inputPrice: 5.0,
48-
outputPrice: 25.0,
49-
cacheWritesPrice: 1.25,
50-
cacheReadsPrice: 1.25,
51-
description: "xAI's Grok-3 fast model with 128K context window",
52-
},
5349
"grok-3-mini": {
5450
maxTokens: 8192,
55-
contextWindow: 131072,
56-
supportsImages: false,
57-
supportsPromptCache: true,
58-
inputPrice: 0.3,
59-
outputPrice: 0.5,
60-
cacheWritesPrice: 0.07,
61-
cacheReadsPrice: 0.07,
51+
contextWindow: 131_072,
6252
description: "xAI's Grok-3 mini model with 128K context window",
6353
supportsReasoningEffort: true,
6454
},
65-
"grok-3-mini-fast": {
66-
maxTokens: 8192,
67-
contextWindow: 131072,
68-
supportsImages: false,
69-
supportsPromptCache: true,
70-
inputPrice: 0.6,
71-
outputPrice: 4.0,
72-
cacheWritesPrice: 0.15,
73-
cacheReadsPrice: 0.15,
74-
description: "xAI's Grok-3 mini fast model with 128K context window",
75-
supportsReasoningEffort: true,
76-
},
7755
"grok-2-1212": {
7856
maxTokens: 8192,
79-
contextWindow: 131072,
80-
supportsImages: false,
81-
supportsPromptCache: false,
82-
inputPrice: 2.0,
83-
outputPrice: 10.0,
84-
description: "xAI's Grok-2 model (version 1212) with 128K context window",
57+
contextWindow: 32_768,
58+
description: "xAI's Grok-2 model (version 1212) with 32K context window",
8559
},
8660
"grok-2-vision-1212": {
8761
maxTokens: 8192,
88-
contextWindow: 32768,
89-
supportsImages: true,
90-
supportsPromptCache: false,
91-
inputPrice: 2.0,
92-
outputPrice: 10.0,
62+
contextWindow: 32_768,
9363
description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window",
9464
},
95-
} as const satisfies Record<string, ModelInfo>
65+
} as const satisfies Record<string, XAIStaticModelInfo>

src/api/providers/__tests__/xai.spec.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,12 @@ describe("XAIHandler", () => {
5757
it("should return default model when no model is specified", () => {
5858
const model = handler.getModel()
5959
expect(model.id).toBe(xaiDefaultModelId)
60-
expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
60+
expect(model.info).toMatchObject({
61+
contextWindow: xaiModels[xaiDefaultModelId].contextWindow,
62+
maxTokens: xaiModels[xaiDefaultModelId].maxTokens,
63+
description: xaiModels[xaiDefaultModelId].description,
64+
})
65+
expect(model.info.supportsPromptCache).toBe(false) // Placeholder until dynamic data loads
6166
})
6267

6368
test("should return specified model when valid model is provided", () => {
@@ -66,7 +71,12 @@ describe("XAIHandler", () => {
6671
const model = handlerWithModel.getModel()
6772

6873
expect(model.id).toBe(testModelId)
69-
expect(model.info).toEqual(xaiModels[testModelId])
74+
expect(model.info).toMatchObject({
75+
contextWindow: xaiModels[testModelId].contextWindow,
76+
maxTokens: xaiModels[testModelId].maxTokens,
77+
description: xaiModels[testModelId].description,
78+
})
79+
expect(model.info.supportsPromptCache).toBe(false) // Placeholder until dynamic data loads
7080
})
7181

7282
it("should include reasoning_effort parameter for mini models", async () => {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import axios from "axios"
3+
4+
vi.mock("axios")
5+
6+
import { getXaiModels } from "../xai"
7+
import { xaiModels } from "@roo-code/types"
8+
9+
describe("getXaiModels", () => {
10+
const mockedAxios = axios as unknown as { get: ReturnType<typeof vi.fn> }
11+
12+
beforeEach(() => {
13+
vi.clearAllMocks()
14+
})
15+
16+
it("returns mapped models with pricing and modalities (augmenting static info when available)", async () => {
17+
mockedAxios.get = vi.fn().mockResolvedValue({
18+
data: {
19+
models: [
20+
{
21+
id: "grok-3",
22+
input_modalities: ["text"],
23+
output_modalities: ["text"],
24+
prompt_text_token_price: 30000,
25+
cached_prompt_text_token_price: 7500,
26+
completion_text_token_price: 150000,
27+
aliases: ["grok-3-latest"],
28+
},
29+
],
30+
},
31+
})
32+
33+
const result = await getXaiModels("key", "https://api.x.ai/v1")
34+
expect(result["grok-3"]).toBeDefined()
35+
expect(result["grok-3"]?.supportsImages).toBe(false)
36+
expect(result["grok-3"]?.inputPrice).toBeCloseTo(300) // $300 per 1M (cents->dollars)
37+
expect(result["grok-3"]?.outputPrice).toBeCloseTo(1500)
38+
expect(result["grok-3"]?.cacheReadsPrice).toBeCloseTo(75)
39+
// aliases are not added to avoid UI duplication
40+
expect(result["grok-3-latest"]).toBeUndefined()
41+
})
42+
43+
it("returns empty object on schema mismatches (graceful degradation)", async () => {
44+
mockedAxios.get = vi.fn().mockResolvedValue({
45+
data: { data: [{ bogus: true }] },
46+
})
47+
const result = await getXaiModels("key")
48+
expect(result).toEqual({})
49+
})
50+
51+
it("includes Authorization header when apiKey provided", async () => {
52+
mockedAxios.get = vi.fn().mockResolvedValue({ data: { data: [] } })
53+
await getXaiModels("secret")
54+
expect((axios.get as any).mock.calls[0][1].headers.Authorization).toBe("Bearer secret")
55+
})
56+
})

src/api/providers/fetchers/modelCache.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import { getIOIntelligenceModels } from "./io-intelligence"
2525
import { getDeepInfraModels } from "./deepinfra"
2626
import { getHuggingFaceModels } from "./huggingface"
2727
import { getRooModels } from "./roo"
28+
import { getXaiModels } from "./xai"
2829

2930
const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
3031

@@ -100,6 +101,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
100101
case "huggingface":
101102
models = await getHuggingFaceModels()
102103
break
104+
case "xai":
105+
models = await getXaiModels(options.apiKey, options.baseUrl)
106+
break
103107
case "roo": {
104108
// Roo Code Cloud provider requires baseUrl and optional apiKey
105109
const rooBaseUrl =
@@ -117,7 +121,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
117121
// Cache the fetched models (even if empty, to signify a successful fetch with no models).
118122
memoryCache.set(provider, models)
119123

120-
await writeModels(provider, models).catch((err) =>
124+
await writeModels(provider, models || {}).catch((err) =>
121125
console.error(`[getModels] Error writing ${provider} models to file cache:`, err),
122126
)
123127

src/api/providers/fetchers/xai.ts

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import axios from "axios"
2+
import { z } from "zod"
3+
4+
import { type ModelInfo, xaiModels } from "@roo-code/types"
5+
import { DEFAULT_HEADERS } from "../../providers/constants"
6+
7+
/**
8+
* Schema for GET https://api.x.ai/v1/language-models
9+
* This endpoint returns rich metadata including modalities and pricing.
10+
*/
11+
const xaiLanguageModelSchema = z.object({
12+
id: z.string(),
13+
input_modalities: z.array(z.string()).optional(),
14+
output_modalities: z.array(z.string()).optional(),
15+
prompt_text_token_price: z.number().optional(), // cents per 1M tokens
16+
cached_prompt_text_token_price: z.number().optional(), // cents per 1M tokens
17+
prompt_image_token_price: z.number().optional(), // cents per 1M tokens
18+
completion_text_token_price: z.number().optional(), // cents per 1M tokens
19+
search_price: z.number().optional(),
20+
aliases: z.array(z.string()).optional(),
21+
})
22+
23+
const xaiLanguageModelsResponseSchema = z.object({
24+
models: z.array(xaiLanguageModelSchema),
25+
})
26+
27+
/**
28+
* Fetch available xAI models for the authenticated account.
29+
* - Uses Bearer Authorization header when apiKey is provided
30+
* - Maps discovered IDs to ModelInfo using static catalog (xaiModels) when possible
31+
* - For models not in static catalog, contextWindow and maxTokens remain undefined
32+
*/
33+
export async function getXaiModels(apiKey?: string, baseUrl?: string): Promise<Record<string, ModelInfo>> {
34+
const models: Record<string, ModelInfo> = {}
35+
// Build proper endpoint whether user passes https://api.x.ai or https://api.x.ai/v1
36+
const base = baseUrl ? baseUrl.replace(/\/+$/, "") : "https://api.x.ai"
37+
const url = base.endsWith("/v1") ? `${base}/language-models` : `${base}/v1/language-models`
38+
39+
try {
40+
const resp = await axios.get(url, {
41+
headers: {
42+
...DEFAULT_HEADERS,
43+
Accept: "application/json",
44+
...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
45+
},
46+
})
47+
48+
const parsed = xaiLanguageModelsResponseSchema.safeParse(resp.data)
49+
const items = parsed.success
50+
? parsed.data.models
51+
: Array.isArray((resp.data as any)?.models)
52+
? (resp.data as any)?.models
53+
: []
54+
55+
if (!parsed.success) {
56+
console.error("xAI language models response validation failed", parsed.error?.format?.() ?? parsed.error)
57+
}
58+
59+
// Helper to convert cents-per-1M to dollars-per-1M (assumption per API examples)
60+
const centsToDollars = (v?: number) => (typeof v === "number" ? v / 100 : undefined)
61+
62+
for (const m of items) {
63+
const id = m.id
64+
const staticInfo = xaiModels[id as keyof typeof xaiModels]
65+
const supportsImages = Array.isArray(m.input_modalities) ? m.input_modalities.includes("image") : false
66+
67+
// Cache support is indicated by presence of cached_prompt_text_token_price field (even if 0)
68+
const supportsPromptCache = typeof m.cached_prompt_text_token_price === "number"
69+
const cacheReadsPrice = supportsPromptCache ? centsToDollars(m.cached_prompt_text_token_price) : undefined
70+
71+
const info: ModelInfo = {
72+
maxTokens: staticInfo?.maxTokens ?? undefined,
73+
contextWindow: staticInfo?.contextWindow ?? undefined,
74+
supportsImages,
75+
supportsPromptCache,
76+
inputPrice: centsToDollars(m.prompt_text_token_price),
77+
outputPrice: centsToDollars(m.completion_text_token_price),
78+
cacheReadsPrice,
79+
cacheWritesPrice: cacheReadsPrice, // xAI uses same price for reads and writes
80+
description: staticInfo?.description,
81+
supportsReasoningEffort:
82+
staticInfo && "supportsReasoningEffort" in staticInfo
83+
? staticInfo.supportsReasoningEffort
84+
: undefined,
85+
// leave other optional fields undefined unless available via static definitions
86+
}
87+
88+
models[id] = info
89+
// Aliases are not added to the model list to avoid duplication in UI
90+
// Users should use the primary model ID; xAI API will handle alias resolution
91+
}
92+
} catch (error) {
93+
try {
94+
const err = JSON.stringify(error, Object.getOwnPropertyNames(error), 2)
95+
console.error(`[xAI] models fetch failed: ${err}`)
96+
} catch {
97+
console.error("[xAI] models fetch failed.")
98+
}
99+
throw error
100+
}
101+
102+
return models
103+
}

0 commit comments

Comments
 (0)