Skip to content

Commit a8a8311

Browse files
committed
feat: add custom base URL support for VertexAI provider
- Add vertexBaseUrl field to provider settings schema - Update VertexHandler to use custom base URL via GeminiHandler - Update AnthropicVertexHandler to use custom base URL - Add UI checkbox and input field for custom base URL in Vertex.tsx - Add comprehensive tests for the new functionality Fixes #7899
1 parent 8fee312 commit a8a8311

File tree

6 files changed

+197
-10
lines changed

6 files changed

+197
-10
lines changed

packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ const vertexSchema = apiModelIdProviderModelSchema.extend({
168168
vertexJsonCredentials: z.string().optional(),
169169
vertexProjectId: z.string().optional(),
170170
vertexRegion: z.string().optional(),
171+
vertexBaseUrl: z.string().optional(),
171172
enableUrlContext: z.boolean().optional(),
172173
enableGrounding: z.boolean().optional(),
173174
})

src/api/providers/__tests__/anthropic-vertex.spec.ts

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,4 +809,95 @@ describe("VertexHandler", () => {
809809
)
810810
})
811811
})
812+
813+
describe("custom base URL", () => {
814+
it("should use custom base URL when provided with JSON credentials", () => {
815+
const customBaseUrl = "https://custom-vertex-endpoint.example.com"
816+
817+
const handler = new AnthropicVertexHandler({
818+
apiModelId: "claude-3-5-sonnet-v2@20241022",
819+
vertexProjectId: "test-project",
820+
vertexRegion: "us-central1",
821+
vertexBaseUrl: customBaseUrl,
822+
vertexJsonCredentials: JSON.stringify({
823+
type: "service_account",
824+
project_id: "test-project",
825+
private_key_id: "key-id",
826+
private_key: "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n",
827+
client_email: "[email protected]",
828+
client_id: "123456789",
829+
auth_uri: "https://accounts.google.com/o/oauth2/auth",
830+
token_uri: "https://oauth2.googleapis.com/token",
831+
auth_provider_x509_cert_url: "https://www.googleapis.com/oauth2/v1/certs",
832+
client_x509_cert_url:
833+
"https://www.googleapis.com/robot/v1/metadata/x509/test%40test.iam.gserviceaccount.com",
834+
}),
835+
})
836+
837+
// Verify that AnthropicVertex was called with baseURL
838+
expect(AnthropicVertex).toHaveBeenCalledWith(
839+
expect.objectContaining({
840+
baseURL: customBaseUrl,
841+
projectId: "test-project",
842+
region: "us-central1",
843+
}),
844+
)
845+
})
846+
847+
it("should use custom base URL when provided with key file", () => {
848+
const customBaseUrl = "https://custom-vertex-endpoint.example.com"
849+
850+
const handler = new AnthropicVertexHandler({
851+
apiModelId: "claude-3-5-sonnet-v2@20241022",
852+
vertexProjectId: "test-project",
853+
vertexRegion: "us-central1",
854+
vertexBaseUrl: customBaseUrl,
855+
vertexKeyFile: "/path/to/keyfile.json",
856+
})
857+
858+
// Verify that AnthropicVertex was called with baseURL
859+
expect(AnthropicVertex).toHaveBeenCalledWith(
860+
expect.objectContaining({
861+
baseURL: customBaseUrl,
862+
projectId: "test-project",
863+
region: "us-central1",
864+
}),
865+
)
866+
})
867+
868+
it("should use custom base URL when provided without credentials", () => {
869+
const customBaseUrl = "https://custom-vertex-endpoint.example.com"
870+
871+
const handler = new AnthropicVertexHandler({
872+
apiModelId: "claude-3-5-sonnet-v2@20241022",
873+
vertexProjectId: "test-project",
874+
vertexRegion: "us-central1",
875+
vertexBaseUrl: customBaseUrl,
876+
})
877+
878+
// Verify that AnthropicVertex was called with baseURL
879+
expect(AnthropicVertex).toHaveBeenCalledWith(
880+
expect.objectContaining({
881+
baseURL: customBaseUrl,
882+
projectId: "test-project",
883+
region: "us-central1",
884+
}),
885+
)
886+
})
887+
888+
it("should not include baseURL when no custom URL is provided", () => {
889+
const handler = new AnthropicVertexHandler({
890+
apiModelId: "claude-3-5-sonnet-v2@20241022",
891+
vertexProjectId: "test-project",
892+
vertexRegion: "us-central1",
893+
})
894+
895+
// Verify that AnthropicVertex was called without baseURL
896+
expect(AnthropicVertex).toHaveBeenCalledWith(
897+
expect.not.objectContaining({
898+
baseURL: expect.anything(),
899+
}),
900+
)
901+
})
902+
})
812903
})

src/api/providers/__tests__/vertex.spec.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,59 @@ describe("VertexHandler", () => {
138138
expect(modelInfo.info.contextWindow).toBe(1048576)
139139
})
140140
})
141+
142+
describe("custom base URL", () => {
143+
it("should use custom base URL when provided", async () => {
144+
const customBaseUrl = "https://custom-vertex-endpoint.example.com"
145+
146+
handler = new VertexHandler({
147+
apiModelId: "gemini-1.5-pro-001",
148+
vertexProjectId: "test-project",
149+
vertexRegion: "us-central1",
150+
vertexBaseUrl: customBaseUrl,
151+
})
152+
153+
// Mock the generateContent method
154+
const mockGenerateContent = vitest.fn().mockResolvedValue({
155+
text: "Test response with custom URL",
156+
})
157+
handler["client"].models.generateContent = mockGenerateContent
158+
159+
await handler.completePrompt("Test prompt")
160+
161+
// Verify that the custom base URL was passed in the config
162+
expect(mockGenerateContent).toHaveBeenCalledWith(
163+
expect.objectContaining({
164+
config: expect.objectContaining({
165+
httpOptions: { baseUrl: customBaseUrl },
166+
}),
167+
}),
168+
)
169+
})
170+
171+
it("should not include httpOptions when no custom base URL is provided", async () => {
172+
handler = new VertexHandler({
173+
apiModelId: "gemini-1.5-pro-001",
174+
vertexProjectId: "test-project",
175+
vertexRegion: "us-central1",
176+
})
177+
178+
// Mock the generateContent method
179+
const mockGenerateContent = vitest.fn().mockResolvedValue({
180+
text: "Test response without custom URL",
181+
})
182+
handler["client"].models.generateContent = mockGenerateContent
183+
184+
await handler.completePrompt("Test prompt")
185+
186+
// Verify that httpOptions is undefined when no custom URL
187+
expect(mockGenerateContent).toHaveBeenCalledWith(
188+
expect.objectContaining({
189+
config: expect.objectContaining({
190+
httpOptions: undefined,
191+
}),
192+
}),
193+
)
194+
})
195+
})
141196
})

src/api/providers/anthropic-vertex.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,34 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
3434
const projectId = this.options.vertexProjectId ?? "not-provided"
3535
const region = this.options.vertexRegion ?? "us-east5"
3636

37+
const baseOptions: any = {
38+
projectId,
39+
region,
40+
}
41+
42+
// Add custom base URL if provided
43+
if (this.options.vertexBaseUrl) {
44+
baseOptions.baseURL = this.options.vertexBaseUrl
45+
}
46+
3747
if (this.options.vertexJsonCredentials) {
3848
this.client = new AnthropicVertex({
39-
projectId,
40-
region,
49+
...baseOptions,
4150
googleAuth: new GoogleAuth({
4251
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
4352
credentials: safeJsonParse<JWTInput>(this.options.vertexJsonCredentials, undefined),
4453
}),
4554
})
4655
} else if (this.options.vertexKeyFile) {
4756
this.client = new AnthropicVertex({
48-
projectId,
49-
region,
57+
...baseOptions,
5058
googleAuth: new GoogleAuth({
5159
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
5260
keyFile: this.options.vertexKeyFile,
5361
}),
5462
})
5563
} else {
56-
this.client = new AnthropicVertex({ projectId, region })
64+
this.client = new AnthropicVertex(baseOptions)
5765
}
5866
}
5967

src/api/providers/gemini.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,13 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
7878
tools.push({ googleSearch: {} })
7979
}
8080

81+
// Use vertexBaseUrl if this is a Vertex handler, otherwise use googleGeminiBaseUrl
82+
const baseUrl =
83+
this.constructor.name === "VertexHandler" ? this.options.vertexBaseUrl : this.options.googleGeminiBaseUrl
84+
8185
const config: GenerateContentConfig = {
8286
systemInstruction,
83-
httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined,
87+
httpOptions: baseUrl ? { baseUrl } : undefined,
8488
thinkingConfig,
8589
maxOutputTokens: this.options.modelMaxTokens ?? maxTokens ?? undefined,
8690
temperature: this.options.modelTemperature ?? 0,
@@ -220,10 +224,14 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
220224
if (this.options.enableGrounding) {
221225
tools.push({ googleSearch: {} })
222226
}
227+
// Use vertexBaseUrl if this is a Vertex handler, otherwise use googleGeminiBaseUrl
228+
const baseUrl =
229+
this.constructor.name === "VertexHandler"
230+
? this.options.vertexBaseUrl
231+
: this.options.googleGeminiBaseUrl
232+
223233
const promptConfig: GenerateContentConfig = {
224-
httpOptions: this.options.googleGeminiBaseUrl
225-
? { baseUrl: this.options.googleGeminiBaseUrl }
226-
: undefined,
234+
httpOptions: baseUrl ? { baseUrl } : undefined,
227235
temperature: this.options.modelTemperature ?? 0,
228236
...(tools.length > 0 ? { tools } : {}),
229237
}

webview-ui/src/components/settings/providers/Vertex.tsx

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useCallback } from "react"
1+
import { useCallback, useState } from "react"
22
import { Checkbox } from "vscrui"
33
import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
44

@@ -18,6 +18,8 @@ type VertexProps = {
1818
export const Vertex = ({ apiConfiguration, setApiConfigurationField, fromWelcomeView }: VertexProps) => {
1919
const { t } = useAppTranslation()
2020

21+
const [vertexBaseUrlSelected, setVertexBaseUrlSelected] = useState(!!apiConfiguration?.vertexBaseUrl)
22+
2123
const handleInputChange = useCallback(
2224
<K extends keyof ProviderSettings, E>(
2325
field: K,
@@ -94,6 +96,28 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField, fromWelcome
9496
</Select>
9597
</div>
9698

99+
<div className="mt-4">
100+
<Checkbox
101+
checked={vertexBaseUrlSelected}
102+
onChange={(checked: boolean) => {
103+
setVertexBaseUrlSelected(checked)
104+
if (!checked) {
105+
setApiConfigurationField("vertexBaseUrl", "")
106+
}
107+
}}>
108+
{t("settings:providers.useCustomBaseUrl")}
109+
</Checkbox>
110+
{vertexBaseUrlSelected && (
111+
<VSCodeTextField
112+
value={apiConfiguration?.vertexBaseUrl || ""}
113+
type="url"
114+
onInput={handleInputChange("vertexBaseUrl")}
115+
placeholder="https://us-central1-aiplatform.googleapis.com"
116+
className="w-full mt-1"
117+
/>
118+
)}
119+
</div>
120+
97121
{!fromWelcomeView && apiConfiguration.apiModelId?.startsWith("gemini") && (
98122
<div className="mt-6">
99123
<Checkbox

0 commit comments

Comments
 (0)