diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 4dfeacbf07..2a84d88d3e 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -168,6 +168,7 @@ const vertexSchema = apiModelIdProviderModelSchema.extend({ vertexJsonCredentials: z.string().optional(), vertexProjectId: z.string().optional(), vertexRegion: z.string().optional(), + vertexBaseUrl: z.string().optional(), enableUrlContext: z.boolean().optional(), enableGrounding: z.boolean().optional(), }) diff --git a/src/api/providers/__tests__/anthropic-vertex.spec.ts b/src/api/providers/__tests__/anthropic-vertex.spec.ts index 9d83f265c7..49fe41e872 100644 --- a/src/api/providers/__tests__/anthropic-vertex.spec.ts +++ b/src/api/providers/__tests__/anthropic-vertex.spec.ts @@ -809,4 +809,92 @@ describe("VertexHandler", () => { ) }) }) + + describe("custom base URL", () => { + it("should use custom base URL when provided with JSON credentials", () => { + const customBaseUrl = "https://custom-vertex-endpoint.example.com" + + const handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexBaseUrl: customBaseUrl, + vertexJsonCredentials: JSON.stringify({ + type: "service_account", + project_id: "test-project", + private_key_id: "key-id", + private_key: "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n", + client_email: "test@test.iam.gserviceaccount.com", + client_id: "123456789", + auth_uri: "https://accounts.google.com/o/oauth2/auth", + token_uri: "https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url: "https://www.googleapis.com/oauth2/v1/certs", + client_x509_cert_url: + "https://www.googleapis.com/robot/v1/metadata/x509/test%40test.iam.gserviceaccount.com", + }), + }) + + // Verify that AnthropicVertex was called with baseURL + expect(AnthropicVertex).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: customBaseUrl, + projectId: "test-project", + region: "us-central1", + }), + ) + }) + + it("should use custom base URL when provided with key file", () => { + const customBaseUrl = "https://custom-vertex-endpoint.example.com" + + const handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexBaseUrl: customBaseUrl, + vertexKeyFile: "/path/to/keyfile.json", + }) + + // Verify that AnthropicVertex was called with baseURL + expect(AnthropicVertex).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: customBaseUrl, + projectId: "test-project", + region: "us-central1", + }), + ) + }) + + it("should use custom base URL when provided without credentials", () => { + const customBaseUrl = "https://custom-vertex-endpoint.example.com" + + const handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexBaseUrl: customBaseUrl, + }) + + // Verify that AnthropicVertex was called with baseURL + expect(AnthropicVertex).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: customBaseUrl, + projectId: "test-project", + region: "us-central1", + }), + ) + }) + + it("should not include baseURL when no custom URL is provided", () => { + const handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-5-sonnet-v2@20241022", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + // Verify that AnthropicVertex was called without baseURL + const callArgs = (AnthropicVertex as any).mock.calls[0][0] + expect(callArgs.baseURL).toBeUndefined() + }) + }) }) diff --git a/src/api/providers/__tests__/vertex.spec.ts b/src/api/providers/__tests__/vertex.spec.ts index d147e79ba8..d140a58130 100644 --- a/src/api/providers/__tests__/vertex.spec.ts +++ b/src/api/providers/__tests__/vertex.spec.ts @@ -138,4 +138,54 @@ describe("VertexHandler", () => { expect(modelInfo.info.contextWindow).toBe(1048576) }) }) + + describe("custom base URL", () => { + it("should use custom base URL when provided", async () => { + const customBaseUrl = "https://custom-vertex-endpoint.example.com" + + handler = new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + vertexBaseUrl: customBaseUrl, + }) + + // Mock the generateContent method + const mockGenerateContent = vitest.fn().mockResolvedValue({ + text: "Test response with custom URL", + }) + handler["client"].models.generateContent = mockGenerateContent + + await handler.completePrompt("Test prompt") + + // Verify that the custom base URL was passed in the config + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + httpOptions: { baseUrl: customBaseUrl }, + }), + }), + ) + }) + + it("should not include httpOptions when no custom base URL is provided", async () => { + handler = new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + // Mock the generateContent method + const mockGenerateContent = vitest.fn().mockResolvedValue({ + text: "Test response without custom URL", + }) + handler["client"].models.generateContent = mockGenerateContent + + await handler.completePrompt("Test prompt") + + // Verify that httpOptions is undefined when no custom URL + const callArgs = mockGenerateContent.mock.calls[0][0] + expect(callArgs.config.httpOptions).toBeUndefined() + }) + }) }) diff --git a/src/api/providers/anthropic-vertex.ts b/src/api/providers/anthropic-vertex.ts index c70a15926d..94a4cd3d59 100644 --- a/src/api/providers/anthropic-vertex.ts +++ b/src/api/providers/anthropic-vertex.ts @@ -34,10 +34,26 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple const projectId = this.options.vertexProjectId ?? "not-provided" const region = this.options.vertexRegion ?? "us-east5" + type VertexOptions = { + projectId: string + region: string + baseURL?: string + googleAuth?: GoogleAuth + } + + const baseOptions: VertexOptions = { + projectId, + region, + } + + // Add custom base URL if provided + if (this.options.vertexBaseUrl) { + baseOptions.baseURL = this.options.vertexBaseUrl + } + if (this.options.vertexJsonCredentials) { this.client = new AnthropicVertex({ - projectId, - region, + ...baseOptions, googleAuth: new GoogleAuth({ scopes: ["https://www.googleapis.com/auth/cloud-platform"], credentials: safeJsonParse(this.options.vertexJsonCredentials, undefined), @@ -45,15 +61,14 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple }) } else if (this.options.vertexKeyFile) { this.client = new AnthropicVertex({ - projectId, - region, + ...baseOptions, googleAuth: new GoogleAuth({ scopes: ["https://www.googleapis.com/auth/cloud-platform"], keyFile: this.options.vertexKeyFile, }), }) } else { - this.client = new AnthropicVertex({ projectId, region }) + this.client = new AnthropicVertex(baseOptions) } } diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 775d763a05..823510208d 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -27,6 +27,7 @@ type GeminiHandlerOptions = ApiHandlerOptions & { export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions + private isVertex: boolean private client: GoogleGenAI @@ -34,6 +35,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl super() this.options = options + this.isVertex = isVertex ?? false const project = this.options.vertexProjectId ?? "not-provided" const location = this.options.vertexRegion ?? "not-provided" @@ -78,9 +80,12 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl tools.push({ googleSearch: {} }) } + // Use vertexBaseUrl if this is a Vertex handler, otherwise use googleGeminiBaseUrl + const baseUrl = this.isVertex ? this.options.vertexBaseUrl : this.options.googleGeminiBaseUrl + const config: GenerateContentConfig = { systemInstruction, - httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined, + httpOptions: baseUrl ? { baseUrl } : undefined, thinkingConfig, maxOutputTokens: this.options.modelMaxTokens ?? maxTokens ?? undefined, temperature: this.options.modelTemperature ?? 0, @@ -220,10 +225,11 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl if (this.options.enableGrounding) { tools.push({ googleSearch: {} }) } + // Use vertexBaseUrl if this is a Vertex handler, otherwise use googleGeminiBaseUrl + const baseUrl = this.isVertex ? this.options.vertexBaseUrl : this.options.googleGeminiBaseUrl + const promptConfig: GenerateContentConfig = { - httpOptions: this.options.googleGeminiBaseUrl - ? { baseUrl: this.options.googleGeminiBaseUrl } - : undefined, + httpOptions: baseUrl ? { baseUrl } : undefined, temperature: this.options.modelTemperature ?? 0, ...(tools.length > 0 ? { tools } : {}), } diff --git a/webview-ui/src/components/settings/providers/Vertex.tsx b/webview-ui/src/components/settings/providers/Vertex.tsx index 1a57f4fa5e..df9fd09279 100644 --- a/webview-ui/src/components/settings/providers/Vertex.tsx +++ b/webview-ui/src/components/settings/providers/Vertex.tsx @@ -1,4 +1,4 @@ -import { useCallback } from "react" +import { useCallback, useState } from "react" import { Checkbox } from "vscrui" import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" @@ -18,6 +18,8 @@ type VertexProps = { export const Vertex = ({ apiConfiguration, setApiConfigurationField, fromWelcomeView }: VertexProps) => { const { t } = useAppTranslation() + const [vertexBaseUrlSelected, setVertexBaseUrlSelected] = useState(!!apiConfiguration?.vertexBaseUrl) + const handleInputChange = useCallback( ( field: K, @@ -94,6 +96,28 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField, fromWelcome +
+ { + setVertexBaseUrlSelected(checked) + if (!checked) { + setApiConfigurationField("vertexBaseUrl", "") + } + }}> + {t("settings:providers.useCustomBaseUrl")} + + {vertexBaseUrlSelected && ( + + )} +
+ {!fromWelcomeView && apiConfiguration.apiModelId?.startsWith("gemini") && (