diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 4dfeacbf07..6d628ddfdf 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -308,9 +308,13 @@ const sambaNovaSchema = apiModelIdProviderModelSchema.extend({ sambaNovaApiKey: z.string().optional(), }) +export const zaiApiLineSchema = z.enum(["international_coding", "international", "china_coding", "china"]) + +export type ZaiApiLine = z.infer + const zaiSchema = apiModelIdProviderModelSchema.extend({ zaiApiKey: z.string().optional(), - zaiApiLine: z.union([z.literal("china"), z.literal("international")]).optional(), + zaiApiLine: zaiApiLineSchema.optional(), }) const fireworksSchema = apiModelIdProviderModelSchema.extend({ diff --git a/packages/types/src/providers/zai.ts b/packages/types/src/providers/zai.ts index f724744827..b3838c1406 100644 --- a/packages/types/src/providers/zai.ts +++ b/packages/types/src/providers/zai.ts @@ -1,4 +1,5 @@ import type { ModelInfo } from "../model.js" +import { ZaiApiLine } from "../provider-settings.js" // Z AI // https://docs.z.ai/guides/llm/glm-4.5 @@ -103,3 +104,14 @@ export const mainlandZAiModels = { } as const satisfies Record export const ZAI_DEFAULT_TEMPERATURE = 0 + +export const zaiApiLineConfigs = { + international_coding: { + name: "International Coding Plan", + baseUrl: "https://api.z.ai/api/coding/paas/v4", + isChina: false, + }, + international: { name: "International Standard", baseUrl: "https://api.z.ai/api/paas/v4", isChina: false }, + china_coding: { name: "China Coding Plan", baseUrl: "https://open.bigmodel.cn/api/coding/paas/v4", isChina: true }, + china: { name: "China Standard", baseUrl: "https://open.bigmodel.cn/api/paas/v4", isChina: true }, +} satisfies Record diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts index a16aa9fcdf..7928a4298d 100644 --- a/src/api/providers/__tests__/zai.spec.ts +++ b/src/api/providers/__tests__/zai.spec.ts @@ -41,7 +41,11 @@ describe("ZAiHandler", () => { it("should use the correct international Z AI base URL", () => { new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international" }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.z.ai/api/paas/v4" })) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://api.z.ai/api/paas/v4", + }), + ) }) it("should use the provided API key for international", () => { @@ -109,7 +113,11 @@ describe("ZAiHandler", () => { describe("Default behavior", () => { it("should default to international when no zaiApiLine is specified", () => { const handlerDefault = new ZAiHandler({ zaiApiKey: "test-zai-api-key" }) - expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.z.ai/api/paas/v4" })) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://api.z.ai/api/coding/paas/v4", + }), + ) const model = handlerDefault.getModel() expect(model.id).toBe(internationalZAiDefaultModelId) diff --git a/src/api/providers/zai.ts b/src/api/providers/zai.ts index e37e37f01b..ce5aab9dd9 100644 --- a/src/api/providers/zai.ts +++ b/src/api/providers/zai.ts @@ -6,6 +6,7 @@ import { type InternationalZAiModelId, type MainlandZAiModelId, ZAI_DEFAULT_TEMPERATURE, + zaiApiLineConfigs, } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" @@ -14,14 +15,14 @@ import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" export class ZAiHandler extends BaseOpenAiCompatibleProvider { constructor(options: ApiHandlerOptions) { - const isChina = options.zaiApiLine === "china" + const isChina = zaiApiLineConfigs[options.zaiApiLine ?? "international_coding"].isChina const models = isChina ? mainlandZAiModels : internationalZAiModels const defaultModelId = isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId super({ ...options, providerName: "Z AI", - baseURL: isChina ? "https://open.bigmodel.cn/api/paas/v4" : "https://api.z.ai/api/paas/v4", + baseURL: zaiApiLineConfigs[options.zaiApiLine ?? "international_coding"].baseUrl, apiKey: options.zaiApiKey ?? "not-provided", defaultProviderModelId: defaultModelId, providerModels: models, diff --git a/webview-ui/src/components/settings/providers/ZAi.tsx b/webview-ui/src/components/settings/providers/ZAi.tsx index bc23f28346..c7f44510c1 100644 --- a/webview-ui/src/components/settings/providers/ZAi.tsx +++ b/webview-ui/src/components/settings/providers/ZAi.tsx @@ -1,7 +1,7 @@ import { useCallback } from "react" import { VSCodeTextField, VSCodeDropdown, VSCodeOption } from "@vscode/webview-ui-toolkit/react" -import type { ProviderSettings } from "@roo-code/types" +import { zaiApiLineConfigs, zaiApiLineSchema, type ProviderSettings } from "@roo-code/types" import { useAppTranslation } from "@src/i18n/TranslationContext" import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" @@ -33,15 +33,17 @@ export const ZAi = ({ apiConfiguration, setApiConfigurationField }: ZAiProps) =>
- - api.z.ai - - - open.bigmodel.cn - + {zaiApiLineSchema.options.map((zaiApiLine) => { + const config = zaiApiLineConfigs[zaiApiLine] + return ( + + {config.name} ({config.baseUrl}) + + ) + })}
{t("settings:providers.zaiEntrypointDescription")} @@ -62,7 +64,7 @@ export const ZAi = ({ apiConfiguration, setApiConfigurationField }: ZAiProps) => {!apiConfiguration?.zaiApiKey && (