Skip to content

Commit a921d05

Browse files
juesooyama777 and authored
Add Z AI provider (#6657)
Co-authored-by: wangshan <[email protected]>
1 parent f24c1e6 commit a921d05

File tree

30 files changed

+558
-0
lines changed

30 files changed

+558
-0
lines changed

packages/types/src/provider-settings.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ export const providerNames = [
3636
"huggingface",
3737
"cerebras",
3838
"sambanova",
39+
"zai",
3940
] as const
4041

4142
export const providerNamesSchema = z.enum(providerNames)
@@ -257,6 +258,11 @@ const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
257258
sambaNovaApiKey: z.string().optional(),
258259
})
259260

261+
const zaiSchema = apiModelIdProviderModelSchema.extend({
262+
zaiApiKey: z.string().optional(),
263+
zaiApiLine: z.union([z.literal("china"), z.literal("international")]).optional(),
264+
})
265+
260266
const defaultSchema = z.object({
261267
apiProvider: z.undefined(),
262268
})
@@ -290,6 +296,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
290296
litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
291297
cerebrasSchema.merge(z.object({ apiProvider: z.literal("cerebras") })),
292298
sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })),
299+
zaiSchema.merge(z.object({ apiProvider: z.literal("zai") })),
293300
defaultSchema,
294301
])
295302

@@ -323,6 +330,7 @@ export const providerSettingsSchema = z.object({
323330
...litellmSchema.shape,
324331
...cerebrasSchema.shape,
325332
...sambaNovaSchema.shape,
333+
...zaiSchema.shape,
326334
...codebaseIndexProviderSchema.shape,
327335
})
328336

packages/types/src/providers/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ export * from "./vertex.js"
2222
export * from "./vscode-llm.js"
2323
export * from "./xai.js"
2424
export * from "./doubao.js"
25+
export * from "./zai.js"
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import type { ModelInfo } from "../model.js"
2+
3+
// Z AI
4+
// https://docs.z.ai/guides/llm/glm-4.5
5+
// https://docs.z.ai/guides/overview/pricing
6+
7+
export type InternationalZAiModelId = keyof typeof internationalZAiModels
8+
export const internationalZAiDefaultModelId: InternationalZAiModelId = "glm-4.5"
9+
export const internationalZAiModels = {
10+
"glm-4.5": {
11+
maxTokens: 98_304,
12+
contextWindow: 131_072,
13+
supportsImages: false,
14+
supportsPromptCache: true,
15+
inputPrice: 0.6,
16+
outputPrice: 2.2,
17+
cacheWritesPrice: 0,
18+
cacheReadsPrice: 0.11,
19+
description:
20+
"GLM-4.5 is Zhipu's latest featured model. Its comprehensive capabilities in reasoning, coding, and agent reach the state-of-the-art (SOTA) level among open-source models, with a context length of up to 128k.",
21+
},
22+
"glm-4.5-air": {
23+
maxTokens: 98_304,
24+
contextWindow: 131_072,
25+
supportsImages: false,
26+
supportsPromptCache: true,
27+
inputPrice: 0.2,
28+
outputPrice: 1.1,
29+
cacheWritesPrice: 0,
30+
cacheReadsPrice: 0.03,
31+
description:
32+
"GLM-4.5-Air is the lightweight version of GLM-4.5. It balances performance and cost-effectiveness, and can flexibly switch to hybrid thinking models.",
33+
},
34+
} as const satisfies Record<string, ModelInfo>
35+
36+
export type MainlandZAiModelId = keyof typeof mainlandZAiModels
37+
export const mainlandZAiDefaultModelId: MainlandZAiModelId = "glm-4.5"
38+
export const mainlandZAiModels = {
39+
"glm-4.5": {
40+
maxTokens: 98_304,
41+
contextWindow: 131_072,
42+
supportsImages: false,
43+
supportsPromptCache: true,
44+
inputPrice: 0.29,
45+
outputPrice: 1.14,
46+
cacheWritesPrice: 0,
47+
cacheReadsPrice: 0.057,
48+
description:
49+
"GLM-4.5 is Zhipu's latest featured model. Its comprehensive capabilities in reasoning, coding, and agent reach the state-of-the-art (SOTA) level among open-source models, with a context length of up to 128k.",
50+
tiers: [
51+
{
52+
contextWindow: 32_000,
53+
inputPrice: 0.21,
54+
outputPrice: 1.0,
55+
cacheReadsPrice: 0.043,
56+
},
57+
{
58+
contextWindow: 128_000,
59+
inputPrice: 0.29,
60+
outputPrice: 1.14,
61+
cacheReadsPrice: 0.057,
62+
},
63+
{
64+
contextWindow: Infinity,
65+
inputPrice: 0.29,
66+
outputPrice: 1.14,
67+
cacheReadsPrice: 0.057,
68+
},
69+
],
70+
},
71+
"glm-4.5-air": {
72+
maxTokens: 98_304,
73+
contextWindow: 131_072,
74+
supportsImages: false,
75+
supportsPromptCache: true,
76+
inputPrice: 0.1,
77+
outputPrice: 0.6,
78+
cacheWritesPrice: 0,
79+
cacheReadsPrice: 0.02,
80+
description:
81+
"GLM-4.5-Air is the lightweight version of GLM-4.5. It balances performance and cost-effectiveness, and can flexibly switch to hybrid thinking models.",
82+
tiers: [
83+
{
84+
contextWindow: 32_000,
85+
inputPrice: 0.07,
86+
outputPrice: 0.4,
87+
cacheReadsPrice: 0.014,
88+
},
89+
{
90+
contextWindow: 128_000,
91+
inputPrice: 0.1,
92+
outputPrice: 0.6,
93+
cacheReadsPrice: 0.02,
94+
},
95+
{
96+
contextWindow: Infinity,
97+
inputPrice: 0.1,
98+
outputPrice: 0.6,
99+
cacheReadsPrice: 0.02,
100+
},
101+
],
102+
},
103+
} as const satisfies Record<string, ModelInfo>
104+
105+
export const ZAI_DEFAULT_TEMPERATURE = 0

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import {
3333
ClaudeCodeHandler,
3434
SambaNovaHandler,
3535
DoubaoHandler,
36+
ZAiHandler,
3637
} from "./providers"
3738

3839
export interface SingleCompletionHandler {
@@ -124,6 +125,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
124125
return new CerebrasHandler(options)
125126
case "sambanova":
126127
return new SambaNovaHandler(options)
128+
case "zai":
129+
return new ZAiHandler(options)
127130
default:
128131
apiProvider satisfies "gemini-cli" | undefined
129132
return new AnthropicHandler(options)
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
// npx vitest run src/api/providers/__tests__/zai.spec.ts
2+
3+
// Mock vscode first to avoid import errors
4+
vitest.mock("vscode", () => ({}))
5+
6+
import OpenAI from "openai"
7+
import { Anthropic } from "@anthropic-ai/sdk"
8+
9+
import {
10+
type InternationalZAiModelId,
11+
type MainlandZAiModelId,
12+
internationalZAiDefaultModelId,
13+
mainlandZAiDefaultModelId,
14+
internationalZAiModels,
15+
mainlandZAiModels,
16+
ZAI_DEFAULT_TEMPERATURE,
17+
} from "@roo-code/types"
18+
19+
import { ZAiHandler } from "../zai"
20+
21+
vitest.mock("openai", () => {
22+
const createMock = vitest.fn()
23+
return {
24+
default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })),
25+
}
26+
})
27+
28+
describe("ZAiHandler", () => {
29+
let handler: ZAiHandler
30+
let mockCreate: any
31+
32+
beforeEach(() => {
33+
vitest.clearAllMocks()
34+
mockCreate = (OpenAI as unknown as any)().chat.completions.create
35+
})
36+
37+
describe("International Z AI", () => {
38+
beforeEach(() => {
39+
handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international" })
40+
})
41+
42+
it("should use the correct international Z AI base URL", () => {
43+
new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international" })
44+
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.z.ai/api/paas/v4" }))
45+
})
46+
47+
it("should use the provided API key for international", () => {
48+
const zaiApiKey = "test-zai-api-key"
49+
new ZAiHandler({ zaiApiKey, zaiApiLine: "international" })
50+
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: zaiApiKey }))
51+
})
52+
53+
it("should return international default model when no model is specified", () => {
54+
const model = handler.getModel()
55+
expect(model.id).toBe(internationalZAiDefaultModelId)
56+
expect(model.info).toEqual(internationalZAiModels[internationalZAiDefaultModelId])
57+
})
58+
59+
it("should return specified international model when valid model is provided", () => {
60+
const testModelId: InternationalZAiModelId = "glm-4.5-air"
61+
const handlerWithModel = new ZAiHandler({
62+
apiModelId: testModelId,
63+
zaiApiKey: "test-zai-api-key",
64+
zaiApiLine: "international",
65+
})
66+
const model = handlerWithModel.getModel()
67+
expect(model.id).toBe(testModelId)
68+
expect(model.info).toEqual(internationalZAiModels[testModelId])
69+
})
70+
})
71+
72+
describe("China Z AI", () => {
73+
beforeEach(() => {
74+
handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "china" })
75+
})
76+
77+
it("should use the correct China Z AI base URL", () => {
78+
new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "china" })
79+
expect(OpenAI).toHaveBeenCalledWith(
80+
expect.objectContaining({ baseURL: "https://open.bigmodel.cn/api/paas/v4" }),
81+
)
82+
})
83+
84+
it("should use the provided API key for China", () => {
85+
const zaiApiKey = "test-zai-api-key"
86+
new ZAiHandler({ zaiApiKey, zaiApiLine: "china" })
87+
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: zaiApiKey }))
88+
})
89+
90+
it("should return China default model when no model is specified", () => {
91+
const model = handler.getModel()
92+
expect(model.id).toBe(mainlandZAiDefaultModelId)
93+
expect(model.info).toEqual(mainlandZAiModels[mainlandZAiDefaultModelId])
94+
})
95+
96+
it("should return specified China model when valid model is provided", () => {
97+
const testModelId: MainlandZAiModelId = "glm-4.5-air"
98+
const handlerWithModel = new ZAiHandler({
99+
apiModelId: testModelId,
100+
zaiApiKey: "test-zai-api-key",
101+
zaiApiLine: "china",
102+
})
103+
const model = handlerWithModel.getModel()
104+
expect(model.id).toBe(testModelId)
105+
expect(model.info).toEqual(mainlandZAiModels[testModelId])
106+
})
107+
})
108+
109+
describe("Default behavior", () => {
110+
it("should default to international when no zaiApiLine is specified", () => {
111+
const handlerDefault = new ZAiHandler({ zaiApiKey: "test-zai-api-key" })
112+
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.z.ai/api/paas/v4" }))
113+
114+
const model = handlerDefault.getModel()
115+
expect(model.id).toBe(internationalZAiDefaultModelId)
116+
expect(model.info).toEqual(internationalZAiModels[internationalZAiDefaultModelId])
117+
})
118+
119+
it("should use 'not-provided' as default API key when none is specified", () => {
120+
new ZAiHandler({ zaiApiLine: "international" })
121+
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: "not-provided" }))
122+
})
123+
})
124+
125+
describe("API Methods", () => {
126+
beforeEach(() => {
127+
handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international" })
128+
})
129+
130+
it("completePrompt method should return text from Z AI API", async () => {
131+
const expectedResponse = "This is a test response from Z AI"
132+
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
133+
const result = await handler.completePrompt("test prompt")
134+
expect(result).toBe(expectedResponse)
135+
})
136+
137+
it("should handle errors in completePrompt", async () => {
138+
const errorMessage = "Z AI API error"
139+
mockCreate.mockRejectedValueOnce(new Error(errorMessage))
140+
await expect(handler.completePrompt("test prompt")).rejects.toThrow(
141+
`Z AI completion error: ${errorMessage}`,
142+
)
143+
})
144+
145+
it("createMessage should yield text content from stream", async () => {
146+
const testContent = "This is test content from Z AI stream"
147+
148+
mockCreate.mockImplementationOnce(() => {
149+
return {
150+
[Symbol.asyncIterator]: () => ({
151+
next: vitest
152+
.fn()
153+
.mockResolvedValueOnce({
154+
done: false,
155+
value: { choices: [{ delta: { content: testContent } }] },
156+
})
157+
.mockResolvedValueOnce({ done: true }),
158+
}),
159+
}
160+
})
161+
162+
const stream = handler.createMessage("system prompt", [])
163+
const firstChunk = await stream.next()
164+
165+
expect(firstChunk.done).toBe(false)
166+
expect(firstChunk.value).toEqual({ type: "text", text: testContent })
167+
})
168+
169+
it("createMessage should yield usage data from stream", async () => {
170+
mockCreate.mockImplementationOnce(() => {
171+
return {
172+
[Symbol.asyncIterator]: () => ({
173+
next: vitest
174+
.fn()
175+
.mockResolvedValueOnce({
176+
done: false,
177+
value: {
178+
choices: [{ delta: {} }],
179+
usage: { prompt_tokens: 10, completion_tokens: 20 },
180+
},
181+
})
182+
.mockResolvedValueOnce({ done: true }),
183+
}),
184+
}
185+
})
186+
187+
const stream = handler.createMessage("system prompt", [])
188+
const firstChunk = await stream.next()
189+
190+
expect(firstChunk.done).toBe(false)
191+
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
192+
})
193+
194+
it("createMessage should pass correct parameters to Z AI client", async () => {
195+
const modelId: InternationalZAiModelId = "glm-4.5"
196+
const modelInfo = internationalZAiModels[modelId]
197+
const handlerWithModel = new ZAiHandler({
198+
apiModelId: modelId,
199+
zaiApiKey: "test-zai-api-key",
200+
zaiApiLine: "international",
201+
})
202+
203+
mockCreate.mockImplementationOnce(() => {
204+
return {
205+
[Symbol.asyncIterator]: () => ({
206+
async next() {
207+
return { done: true }
208+
},
209+
}),
210+
}
211+
})
212+
213+
const systemPrompt = "Test system prompt for Z AI"
214+
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Z AI" }]
215+
216+
const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
217+
await messageGenerator.next()
218+
219+
expect(mockCreate).toHaveBeenCalledWith(
220+
expect.objectContaining({
221+
model: modelId,
222+
max_tokens: modelInfo.maxTokens,
223+
temperature: ZAI_DEFAULT_TEMPERATURE,
224+
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
225+
stream: true,
226+
stream_options: { include_usage: true },
227+
}),
228+
)
229+
})
230+
})
231+
})

0 commit comments

Comments (0)