Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ const vertexSchema = apiModelIdProviderModelSchema.extend({
vertexJsonCredentials: z.string().optional(),
vertexProjectId: z.string().optional(),
vertexRegion: z.string().optional(),
vertexBaseUrl: z.string().optional(),
enableUrlContext: z.boolean().optional(),
enableGrounding: z.boolean().optional(),
})
Expand Down
91 changes: 91 additions & 0 deletions src/api/providers/__tests__/anthropic-vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -809,4 +809,95 @@ describe("VertexHandler", () => {
)
})
})

describe("custom base URL", () => {
it("should use custom base URL when provided with JSON credentials", () => {
const customBaseUrl = "https://custom-vertex-endpoint.example.com"

const handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
vertexBaseUrl: customBaseUrl,
vertexJsonCredentials: JSON.stringify({
type: "service_account",
project_id: "test-project",
private_key_id: "key-id",
private_key: "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n",
client_email: "[email protected]",
client_id: "123456789",
auth_uri: "https://accounts.google.com/o/oauth2/auth",
token_uri: "https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url: "https://www.googleapis.com/oauth2/v1/certs",
client_x509_cert_url:
"https://www.googleapis.com/robot/v1/metadata/x509/test%40test.iam.gserviceaccount.com",
}),
})

// Verify that AnthropicVertex was called with baseURL
expect(AnthropicVertex).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: customBaseUrl,
projectId: "test-project",
region: "us-central1",
}),
)
})

it("should use custom base URL when provided with key file", () => {
const customBaseUrl = "https://custom-vertex-endpoint.example.com"

const handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
vertexBaseUrl: customBaseUrl,
vertexKeyFile: "/path/to/keyfile.json",
})

// Verify that AnthropicVertex was called with baseURL
expect(AnthropicVertex).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: customBaseUrl,
projectId: "test-project",
region: "us-central1",
}),
)
})

it("should use custom base URL when provided without credentials", () => {
const customBaseUrl = "https://custom-vertex-endpoint.example.com"

const handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
vertexBaseUrl: customBaseUrl,
})

// Verify that AnthropicVertex was called with baseURL
expect(AnthropicVertex).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: customBaseUrl,
projectId: "test-project",
region: "us-central1",
}),
)
})

it("should not include baseURL when no custom URL is provided", () => {
const handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

// Verify that AnthropicVertex was called without baseURL
expect(AnthropicVertex).toHaveBeenCalledWith(
expect.not.objectContaining({
baseURL: expect.anything(),
}),
)
})
})
})
55 changes: 55 additions & 0 deletions src/api/providers/__tests__/vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,59 @@ describe("VertexHandler", () => {
expect(modelInfo.info.contextWindow).toBe(1048576)
})
})

describe("custom base URL", () => {
it("should use custom base URL when provided", async () => {
const customBaseUrl = "https://custom-vertex-endpoint.example.com"

handler = new VertexHandler({
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
vertexBaseUrl: customBaseUrl,
})

// Mock the generateContent method
const mockGenerateContent = vitest.fn().mockResolvedValue({
text: "Test response with custom URL",
})
handler["client"].models.generateContent = mockGenerateContent

await handler.completePrompt("Test prompt")

// Verify that the custom base URL was passed in the config
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
httpOptions: { baseUrl: customBaseUrl },
}),
}),
)
})

it("should not include httpOptions when no custom base URL is provided", async () => {
handler = new VertexHandler({
apiModelId: "gemini-1.5-pro-001",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

// Mock the generateContent method
const mockGenerateContent = vitest.fn().mockResolvedValue({
text: "Test response without custom URL",
})
handler["client"].models.generateContent = mockGenerateContent

await handler.completePrompt("Test prompt")

// Verify that httpOptions is undefined when no custom URL
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
httpOptions: undefined,
}),
}),
)
})
})
})
18 changes: 13 additions & 5 deletions src/api/providers/anthropic-vertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,34 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
const projectId = this.options.vertexProjectId ?? "not-provided"
const region = this.options.vertexRegion ?? "us-east5"

const baseOptions: any = {
projectId,
region,
}

// Add custom base URL if provided
if (this.options.vertexBaseUrl) {
baseOptions.baseURL = this.options.vertexBaseUrl
}

if (this.options.vertexJsonCredentials) {
this.client = new AnthropicVertex({
projectId,
region,
...baseOptions,
googleAuth: new GoogleAuth({
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
credentials: safeJsonParse<JWTInput>(this.options.vertexJsonCredentials, undefined),
}),
})
} else if (this.options.vertexKeyFile) {
this.client = new AnthropicVertex({
projectId,
region,
...baseOptions,
googleAuth: new GoogleAuth({
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
keyFile: this.options.vertexKeyFile,
}),
})
} else {
this.client = new AnthropicVertex({ projectId, region })
this.client = new AnthropicVertex(baseOptions)
}
}

Expand Down
16 changes: 12 additions & 4 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
tools.push({ googleSearch: {} })
}

// Use vertexBaseUrl if this is a Vertex handler, otherwise use googleGeminiBaseUrl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same base URL selection logic is repeated in both createMessage and completePrompt. Consider refactoring this logic into a shared helper to reduce duplication.

const baseUrl =
this.constructor.name === "VertexHandler" ? this.options.vertexBaseUrl : this.options.googleGeminiBaseUrl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using this.constructor.name === "VertexHandler" to detect a Vertex handler is brittle (e.g. it may break under minification). Consider using the provided isVertex flag (or another dedicated property) to decide which base URL to use.


const config: GenerateContentConfig = {
systemInstruction,
httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined,
httpOptions: baseUrl ? { baseUrl } : undefined,
thinkingConfig,
maxOutputTokens: this.options.modelMaxTokens ?? maxTokens ?? undefined,
temperature: this.options.modelTemperature ?? 0,
Expand Down Expand Up @@ -220,10 +224,14 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
if (this.options.enableGrounding) {
tools.push({ googleSearch: {} })
}
// Use vertexBaseUrl if this is a Vertex handler, otherwise use googleGeminiBaseUrl
const baseUrl =
this.constructor.name === "VertexHandler"
? this.options.vertexBaseUrl
: this.options.googleGeminiBaseUrl

const promptConfig: GenerateContentConfig = {
httpOptions: this.options.googleGeminiBaseUrl
? { baseUrl: this.options.googleGeminiBaseUrl }
: undefined,
httpOptions: baseUrl ? { baseUrl } : undefined,
temperature: this.options.modelTemperature ?? 0,
...(tools.length > 0 ? { tools } : {}),
}
Expand Down
26 changes: 25 additions & 1 deletion webview-ui/src/components/settings/providers/Vertex.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useCallback } from "react"
import { useCallback, useState } from "react"
import { Checkbox } from "vscrui"
import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"

Expand All @@ -18,6 +18,8 @@ type VertexProps = {
export const Vertex = ({ apiConfiguration, setApiConfigurationField, fromWelcomeView }: VertexProps) => {
const { t } = useAppTranslation()

const [vertexBaseUrlSelected, setVertexBaseUrlSelected] = useState(!!apiConfiguration?.vertexBaseUrl)

const handleInputChange = useCallback(
<K extends keyof ProviderSettings, E>(
field: K,
Expand Down Expand Up @@ -94,6 +96,28 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField, fromWelcome
</Select>
</div>

<div className="mt-4">
<Checkbox
checked={vertexBaseUrlSelected}
onChange={(checked: boolean) => {
setVertexBaseUrlSelected(checked)
if (!checked) {
setApiConfigurationField("vertexBaseUrl", "")
}
}}>
{t("settings:providers.useCustomBaseUrl")}
</Checkbox>
{vertexBaseUrlSelected && (
<VSCodeTextField
value={apiConfiguration?.vertexBaseUrl || ""}
type="url"
onInput={handleInputChange("vertexBaseUrl")}
placeholder="https://us-central1-aiplatform.googleapis.com"
className="w-full mt-1"
/>
)}
</div>

{!fromWelcomeView && apiConfiguration.apiModelId?.startsWith("gemini") && (
<div className="mt-6">
<Checkbox
Expand Down
Loading