
Commit ba6a37b

fix(xai): compute totalCost from usage and load dynamic pricing; update test

1 parent: 622bc4a

2 files changed, +36 -3 lines

src/api/providers/__tests__/xai.spec.ts (2 additions & 1 deletion)

@@ -244,12 +244,13 @@ describe("XAIHandler", () => {
 
 			// Verify the usage data
 			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({
+			expect(firstChunk.value).toMatchObject({
 				type: "usage",
 				inputTokens: 10,
 				outputTokens: 20,
 				cacheReadTokens: 5,
 				cacheWriteTokens: 15,
+				totalCost: expect.any(Number),
 			})
 		})

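The matcher change above is what lets the new totalCost field through: toEqual requires an exact shape and would fail once the usage chunk gains an extra key, while toMatchObject checks only the listed properties. A minimal standalone sketch of the difference (the chunk literal here is hypothetical, not taken from the repo):

	// Hypothetical usage chunk with the new field present.
	const chunk = { type: "usage", inputTokens: 10, totalCost: 0.00042 }

	// toEqual compares the full shape, so the extra totalCost key fails it:
	// expect(chunk).toEqual({ type: "usage", inputTokens: 10 }) // throws

	// toMatchObject ignores keys that are not listed:
	expect(chunk).toMatchObject({ type: "usage", inputTokens: 10 }) // passes

	// expect.any(Number) pins the type without asserting an exact price:
	expect(chunk).toMatchObject({ totalCost: expect.any(Number) })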
src/api/providers/xai.ts (34 additions & 2 deletions)

@@ -13,13 +13,17 @@ import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { handleOpenAIError } from "./utils/openai-error-handler"
+import { calculateApiCostOpenAI } from "../../shared/cost"
+import type { ModelRecord } from "../../shared/api"
+import { getModels } from "./fetchers/modelCache"
 
 const XAI_DEFAULT_TEMPERATURE = 0
 
 export class XAIHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
 	private readonly providerName = "xAI"
+	protected models: ModelRecord = {}
 
 	constructor(options: ApiHandlerOptions) {
 		super()
@@ -39,12 +43,18 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 		const id = this.options.apiModelId ?? xaiDefaultModelId
 
 		const staticInfo = (xaiModels as Record<string, any>)[id as any]
+		const dynamicInfo = this.models?.[id as any]
 
-		// Build complete ModelInfo with required fields; dynamic data comes from router models
+		// Build complete ModelInfo using dynamic pricing/capabilities when available
 		const info: ModelInfo = {
 			contextWindow: this.options.xaiModelContextWindow ?? staticInfo?.contextWindow,
 			maxTokens: staticInfo?.maxTokens ?? undefined,
-			supportsPromptCache: false, // Placeholder - actual value comes from dynamic API call
+			supportsPromptCache: dynamicInfo?.supportsPromptCache ?? false,
+			supportsImages: dynamicInfo?.supportsImages,
+			inputPrice: dynamicInfo?.inputPrice,
+			outputPrice: dynamicInfo?.outputPrice,
+			cacheReadsPrice: dynamicInfo?.cacheReadsPrice,
+			cacheWritesPrice: dynamicInfo?.cacheWritesPrice,
 			description: staticInfo?.description,
 			supportsReasoningEffort:
 				staticInfo && "supportsReasoningEffort" in staticInfo ? staticInfo.supportsReasoningEffort : undefined,
@@ -54,11 +64,24 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 		return { id, info, ...params }
 	}
 
+	private async loadDynamicModels(): Promise<void> {
+		try {
+			this.models = await getModels({
+				provider: "xai",
+				apiKey: this.options.xaiApiKey,
+				baseUrl: (this.client as any).baseURL || "https://api.x.ai/v1",
+			})
+		} catch (error) {
+			console.error("[XAI] Error loading dynamic models:", error)
+		}
+	}
+
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
+		await this.loadDynamicModels()
 		const { id: modelId, info: modelInfo, reasoning } = this.getModel()
 
 		// Use the OpenAI-compatible API.
@@ -107,12 +130,21 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 			const writeTokens =
 				"cache_creation_input_tokens" in chunk.usage ? (chunk.usage as any).cache_creation_input_tokens : 0
 
+			const totalCost = calculateApiCostOpenAI(
+				modelInfo,
+				chunk.usage.prompt_tokens || 0,
+				chunk.usage.completion_tokens || 0,
+				writeTokens || 0,
+				readTokens || 0,
+			)
+
 			yield {
 				type: "usage",
 				inputTokens: chunk.usage.prompt_tokens || 0,
 				outputTokens: chunk.usage.completion_tokens || 0,
 				cacheReadTokens: readTokens,
 				cacheWriteTokens: writeTokens,
+				totalCost,
 			}
 		}
 	}

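calculateApiCostOpenAI itself lives in src/shared/cost and is not part of this diff. A minimal sketch of the arithmetic it presumably performs, assuming ModelInfo prices are USD per million tokens and that OpenAI-style prompt_tokens already include the cache-read and cache-creation tokens (both assumptions, not confirmed by this commit):

	// Sketch only; the real implementation in src/shared/cost may differ.
	interface PricingInfo {
		inputPrice?: number // USD per 1M uncached input tokens (assumed)
		outputPrice?: number // USD per 1M output tokens
		cacheWritesPrice?: number // USD per 1M cache-creation tokens
		cacheReadsPrice?: number // USD per 1M cache-read tokens
	}

	// Parameter order mirrors the call site above: writes before reads.
	function sketchApiCostOpenAI(
		info: PricingInfo,
		inputTokens: number, // assumed inclusive of cache reads/writes
		outputTokens: number,
		cacheWriteTokens = 0,
		cacheReadTokens = 0,
	): number {
		// Subtract the cached portions so each is billed at its own rate.
		const uncachedInput = Math.max(0, inputTokens - cacheWriteTokens - cacheReadTokens)
		const per = (pricePerMillion = 0) => pricePerMillion / 1_000_000
		return (
			per(info.inputPrice) * uncachedInput +
			per(info.cacheWritesPrice) * cacheWriteTokens +
			per(info.cacheReadsPrice) * cacheReadTokens +
			per(info.outputPrice) * outputTokens
		)
	}

Under this reading, a failed dynamic fetch leaves all prices undefined, the sketch treats them as 0, and totalCost degrades to 0 rather than breaking the stream, which matches loadDynamicModels catching and logging its error instead of rethrowing.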